From 5f9a966bb523bf61daefcff209199bc774fa5ed6 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Fri, 15 Aug 2025 20:34:51 -0300 Subject: [PATCH 01/33] fix: ensure server-initiated requests include a valid request_id (#80) --- .../src/mcp_handlers/mcp_server_handler.rs | 1 + .../src/mcp_runtimes/server_runtime.rs | 91 ++++++++++++++----- .../server_runtime/mcp_server_runtime_core.rs | 1 + .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 4 +- 4 files changed, 73 insertions(+), 24 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 5b0fdc0..bf3fe17 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -51,6 +51,7 @@ pub trait ServerHandler: Send + Sync + 'static { runtime .set_client_details(initialize_request.params.clone()) + .await .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; Ok(server_info) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 28cdd8c..d1a8a26 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -1,30 +1,27 @@ pub mod mcp_server_runtime; pub mod mcp_server_runtime_core; +use crate::error::SdkResult; +use crate::mcp_traits::mcp_handler::McpServerHandler; +use crate::mcp_traits::mcp_server::McpServer; use crate::schema::{ schema_utils::{ - ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, - ServerMessages, + ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromServer, SdkError, + ServerMessage, ServerMessages, }, InitializeRequestParams, InitializeResult, RequestId, RpcError, }; - use async_trait::async_trait; use futures::future::try_join_all; use futures::{StreamExt, TryFutureExt}; - +#[cfg(feature = "hyper-server")] +use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; - use std::collections::HashMap; use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::io::AsyncWriteExt; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, watch}; -use crate::error::SdkResult; -use crate::mcp_traits::mcp_handler::McpServerHandler; -use crate::mcp_traits::mcp_server::McpServer; -#[cfg(feature = "hyper-server")] -use rust_mcp_transport::SessionId; pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; // Define a type alias for the TransportDispatcher trait object @@ -49,21 +46,32 @@ pub struct ServerRuntime { #[cfg(feature = "hyper-server")] session_id: Option, transport_map: tokio::sync::RwLock>, + client_details_tx: watch::Sender>, + client_details_rx: watch::Receiver>, } #[async_trait] impl McpServer for ServerRuntime { /// Set the client details, storing them in client_details - fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { - match self.client_details.write() { - Ok(mut details) => { - *details = Some(client_details); - Ok(()) + async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { + self.handler.on_server_started(self).await; + + self.client_details_tx + .send(Some(client_details)) + .map_err(|_| { + RpcError::internal_error() + .with_message("Failed to set client details".to_string()) + .into() + }) + } + + async fn wait_for_initialization(&self) { + loop { + if self.client_details_rx.borrow().is_some() { + return; } - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - Err(_) => Err(RpcError::internal_error() - .with_message("Internal Error: Failed to acquire write lock.".to_string()) - .into()), + let mut rx = self.client_details_rx.clone(); + rx.changed().await.ok(); } } @@ -79,7 +87,19 @@ impl McpServer for ServerRuntime { .with_message("transport stream does not exists or is closed!".to_string()), )?; - let mcp_message = ServerMessage::from_message(message, request_id)?; + // generate a new request_id for request messages + let outgoing_request_id = if message.is_request() { + match request_id { + Some(_) => Err(RpcError::internal_error().with_message( + "request_id should not have a value when sending a new request".to_string(), + )), + None => Ok(self.next_request_id(transport).await), + } + } else { + Ok(request_id) + }?; + + let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?; transport .send_message(ServerMessages::Single(mcp_message), request_timeout) .map_err(|err| err.into()) @@ -130,8 +150,6 @@ impl McpServer for ServerRuntime { let mut stream = transport.start().await?; - self.handler.on_server_started(self).await; - // Process incoming messages from the client while let Some(mcp_messages) = stream.next().await { match mcp_messages { @@ -207,6 +225,25 @@ impl ServerRuntime { Ok(()) } + pub(crate) async fn next_request_id( + &self, + transport: &Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, + >, + ) -> Option { + let message_sender = transport.message_sender(); + let guard = message_sender.read().await; + guard + .as_ref() + .map(|dispatcher| dispatcher.next_request_id()) + } + pub(crate) async fn handle_message( &self, message: ClientMessage, @@ -416,12 +453,16 @@ impl ServerRuntime { handler: Arc, session_id: SessionId, ) -> Self { + let (client_details_tx, client_details_rx) = + watch::channel::>(None); Self { server_details, client_details: Arc::new(RwLock::new(None)), handler, session_id: Some(session_id), transport_map: tokio::sync::RwLock::new(HashMap::new()), + client_details_tx, + client_details_rx, } } @@ -438,6 +479,8 @@ impl ServerRuntime { ) -> Self { let mut map: HashMap = HashMap::new(); map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); + let (client_details_tx, client_details_rx) = + watch::channel::>(None); Self { server_details: Arc::new(server_details), client_details: Arc::new(RwLock::new(None)), @@ -445,6 +488,8 @@ impl ServerRuntime { #[cfg(feature = "hyper-server")] session_id: None, transport_map: tokio::sync::RwLock::new(map), + client_details_tx, + client_details_rx, } } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index 27f04df..154b4bc 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -76,6 +76,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> // keep a copy of the InitializeRequestParams which includes client_info and capabilities runtime .set_client_details(initialize_request.params.clone()) + .await .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index cf0f168..a1d501d 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -23,10 +23,12 @@ use crate::{error::SdkResult, utils::format_assertion_message}; #[async_trait] pub trait McpServer: Sync + Send { async fn start(&self) -> SdkResult<()>; - fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; + async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; fn server_info(&self) -> &InitializeResult; fn client_info(&self) -> Option; + async fn wait_for_initialization(&self); + #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] fn get_client_info(&self) -> Option { self.client_info() From 1ca8e49860e990c3562623e75dd723b0d1dc8256 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Sat, 16 Aug 2025 17:10:57 -0300 Subject: [PATCH 02/33] fix: abort keep-alive task when transport is removed (#82) --- crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d1a8a26..69e8b88 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -392,7 +392,10 @@ impl ServerRuntime { let transport = self.transport_by_stream(stream_id).await?; let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); - let _ = transport.keep_alive(ping_interval, disconnect_tx).await; + let abort_alive_task = transport + .keep_alive(ping_interval, disconnect_tx) + .await? + .abort_handle(); // in case there is a payload, we consume it by transport to get processed if let Some(payload) = payload { @@ -429,11 +432,13 @@ impl ServerRuntime { } // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ + abort_alive_task.abort(); return Ok(()); } } _ = &mut disconnect_rx => { self.remove_transport(stream_id).await?; + abort_alive_task.abort(); // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); From 36dfa4cdc821e958ffe78b909ed28f5577d113c8 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 16 Aug 2025 19:48:36 -0300 Subject: [PATCH 03/33] feat: integrate list root and client info into hyper runtime --- .../src/hyper_servers/hyper_runtime.rs | 27 +++++- .../src/mcp_runtimes/server_runtime.rs | 37 ++++---- .../tests/test_streamable_http.rs | 87 +++++++++++++++++-- 3 files changed, 129 insertions(+), 22 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs index 30df951..109dde5 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -4,7 +4,8 @@ use crate::{ mcp_server::HyperServer, schema::{ schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, - CreateMessageRequestParams, CreateMessageResult, LoggingMessageNotificationParams, + CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams, + ListRootsRequestParams, ListRootsResult, LoggingMessageNotificationParams, PromptListChangedNotificationParams, ResourceListChangedNotificationParams, ResourceUpdatedNotificationParams, ToolListChangedNotificationParams, }, @@ -99,6 +100,21 @@ impl HyperRuntime { runtime.send_notification(notification).await } + /// Request a list of root URIs from the client. Roots allow + /// servers to ask for specific directories or files to operate on. A common example + /// for roots is providing a set of repositories or directories a server should operate on. + /// This request is typically used when the server needs to understand the file system + /// structure or access specific locations that the client has permission to read from + pub async fn list_roots( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.list_roots(params).await + } + pub async fn send_logging_message( &self, session_id: &SessionId, @@ -195,4 +211,13 @@ impl HyperRuntime { let runtime = runtime.lock().await.to_owned(); runtime.create_message(params).await } + + pub async fn client_info( + &self, + session_id: &SessionId, + ) -> SdkResult> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + Ok(runtime.client_info()) + } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d1a8a26..da6150a 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -87,17 +87,9 @@ impl McpServer for ServerRuntime { .with_message("transport stream does not exists or is closed!".to_string()), )?; - // generate a new request_id for request messages - let outgoing_request_id = if message.is_request() { - match request_id { - Some(_) => Err(RpcError::internal_error().with_message( - "request_id should not have a value when sending a new request".to_string(), - )), - None => Ok(self.next_request_id(transport).await), - } - } else { - Ok(request_id) - }?; + let outgoing_request_id = self + .request_id_for_message(transport, &message, request_id) + .await; let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?; transport @@ -225,7 +217,18 @@ impl ServerRuntime { Ok(()) } - pub(crate) async fn next_request_id( + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + pub(crate) async fn request_id_for_message( &self, transport: &Arc< dyn TransportDispatcher< @@ -236,12 +239,16 @@ impl ServerRuntime { ServerMessage, >, >, + message: &MessageFromServer, + request_id: Option, ) -> Option { let message_sender = transport.message_sender(); let guard = message_sender.read().await; - guard - .as_ref() - .map(|dispatcher| dispatcher.next_request_id()) + if let Some(dispatcher) = guard.as_ref() { + dispatcher.request_id_for_message(message, request_id) + } else { + None + } } pub(crate) async fn handle_message( diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http.rs index 08c85e8..5eb5e47 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http.rs @@ -3,13 +3,14 @@ use std::{collections::HashMap, error::Error, sync::Arc, time::Duration, vec}; use hyper::StatusCode; use rust_mcp_schema::{ schema_utils::{ - ClientJsonrpcRequest, ClientMessage, ClientMessages, FromMessage, NotificationFromServer, - ResultFromServer, RpcMessage, SdkError, SdkErrorCodes, ServerJsonrpcNotification, - ServerJsonrpcResponse, ServerMessages, + ClientJsonrpcRequest, ClientJsonrpcResponse, ClientMessage, ClientMessages, FromMessage, + NotificationFromServer, RequestFromServer, ResultFromServer, RpcMessage, SdkError, + SdkErrorCodes, ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse, + ServerMessages, }, - CallToolRequest, CallToolRequestParams, ListToolsRequest, LoggingLevel, - LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, - ServerResult, + CallToolRequest, CallToolRequestParams, ListPromptsRequestParams, ListRootsRequestParams, + ListRootsResult, ListToolsRequest, LoggingLevel, LoggingMessageNotificationParams, RequestId, + RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, }; use rust_mcp_sdk::mcp_server::HyperServerOptions; use serde_json::{json, Map, Value}; @@ -364,6 +365,80 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { server.hyper_runtime.await_server().await.unwrap() } +// should establish standalone SSE stream and receive server-initiated messages +#[tokio::test] +async fn should_establish_standalone_stream_and_receive_server_requests() { + let (server, session_id) = initialize_server(None).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id).await; + + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!( + response + .headers() + .get("mcp-session-id") + .unwrap() + .to_str() + .unwrap(), + session_id + ); + + assert_eq!( + response + .headers() + .get("content-type") + .unwrap() + .to_str() + .unwrap(), + "text/event-stream" + ); + + let hyper_server = Arc::new(server.hyper_runtime); + let hyper_server_clone = hyper_server.clone(); + let session_id_clone = session_id.to_string(); + + tokio::spawn(async move { + // Send a server-initiated notification that should appear on SSE stream with a valid request_id + hyper_server_clone + .list_roots(&session_id_clone, None) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(2250)).await; + + let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( + RequestId::Integer(0), + ListRootsResult { + meta: None, + roots: vec![], + } + .into(), + ); + + send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + let event = read_sse_event(response).await.unwrap(); + + let message: ServerJsonrpcRequest = serde_json::from_str(&event).unwrap(); + + println!(">>> message {:?} ", message); + + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message.request + else { + panic!("invalid message received!"); + }; + + hyper_server.graceful_shutdown(ONE_MILLISECOND); +} + // should not close GET SSE stream after sending multiple server notifications #[tokio::test] async fn should_not_close_get_sse_stream() { From 6995d25ca43dda30d679b187fe150768be8acecf Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 16 Aug 2025 19:53:11 -0300 Subject: [PATCH 04/33] chore: remove unused import --- crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 368955d..d7b53a1 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -5,8 +5,8 @@ use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; use crate::schema::{ schema_utils::{ - ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromServer, SdkError, - ServerMessage, ServerMessages, + ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, + ServerMessages, }, InitializeRequestParams, InitializeResult, RequestId, RpcError, }; From aacdbfebd063f8267b5faf9437e68da7d4419aa7 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Sat, 16 Aug 2025 20:11:00 -0300 Subject: [PATCH 05/33] chore: release main (#81) --- .release-manifest.json | 18 +++++++++--------- Cargo.lock | 18 +++++++++--------- crates/rust-mcp-sdk/CHANGELOG.md | 13 +++++++++++++ crates/rust-mcp-sdk/Cargo.toml | 2 +- .../hello-world-mcp-server-core/Cargo.toml | 2 +- examples/hello-world-mcp-server/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- examples/simple-mcp-client-core-sse/Cargo.toml | 2 +- examples/simple-mcp-client-core/Cargo.toml | 2 +- examples/simple-mcp-client-sse/Cargo.toml | 2 +- examples/simple-mcp-client/Cargo.toml | 2 +- 12 files changed, 40 insertions(+), 27 deletions(-) diff --git a/.release-manifest.json b/.release-manifest.json index e8ad288..36d8135 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,13 @@ { - "crates/rust-mcp-sdk": "0.5.1", + "crates/rust-mcp-sdk": "0.5.2", "crates/rust-mcp-macros": "0.5.1", "crates/rust-mcp-transport": "0.4.1", - "examples/hello-world-mcp-server": "0.1.25", - "examples/hello-world-mcp-server-core": "0.1.16", - "examples/simple-mcp-client": "0.1.25", - "examples/simple-mcp-client-core": "0.1.25", - "examples/hello-world-server-core-streamable-http": "0.1.16", - "examples/hello-world-server-streamable-http": "0.1.25", - "examples/simple-mcp-client-core-sse": "0.1.16", - "examples/simple-mcp-client-sse": "0.1.16" + "examples/hello-world-mcp-server": "0.1.26", + "examples/hello-world-mcp-server-core": "0.1.17", + "examples/simple-mcp-client": "0.1.26", + "examples/simple-mcp-client-core": "0.1.26", + "examples/hello-world-server-core-streamable-http": "0.1.17", + "examples/hello-world-server-streamable-http": "0.1.26", + "examples/simple-mcp-client-core-sse": "0.1.17", + "examples/simple-mcp-client-sse": "0.1.17" } diff --git a/Cargo.lock b/Cargo.lock index df081df..23688b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -688,7 +688,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "hello-world-mcp-server" -version = "0.1.25" +version = "0.1.26" dependencies = [ "async-trait", "futures", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "hello-world-mcp-server-core" -version = "0.1.16" +version = "0.1.17" dependencies = [ "async-trait", "futures", @@ -714,7 +714,7 @@ dependencies = [ [[package]] name = "hello-world-server-core-streamable-http" -version = "0.1.16" +version = "0.1.17" dependencies = [ "async-trait", "futures", @@ -728,7 +728,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http" -version = "0.1.25" +version = "0.1.26" dependencies = [ "async-trait", "futures", @@ -1699,7 +1699,7 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.5.1" +version = "0.5.2" dependencies = [ "async-trait", "axum", @@ -1924,7 +1924,7 @@ dependencies = [ [[package]] name = "simple-mcp-client" -version = "0.1.25" +version = "0.1.26" dependencies = [ "async-trait", "colored", @@ -1938,7 +1938,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core" -version = "0.1.25" +version = "0.1.26" dependencies = [ "async-trait", "colored", @@ -1952,7 +1952,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core-sse" -version = "0.1.16" +version = "0.1.17" dependencies = [ "async-trait", "colored", @@ -1968,7 +1968,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse" -version = "0.1.16" +version = "0.1.17" dependencies = [ "async-trait", "colored", diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index 8f2f4f7..5588727 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## [0.5.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.1...rust-mcp-sdk-v0.5.2) (2025-08-16) + + +### πŸš€ Features + +* Integrate list root and client info into hyper runtime ([36dfa4c](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/36dfa4cdc821e958ffe78b909ed28f5577d113c8)) + + +### πŸ› Bug Fixes + +* Abort keep-alive task when transport is removed ([#82](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/82)) ([1ca8e49](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/1ca8e49860e990c3562623e75dd723b0d1dc8256)) +* Ensure server-initiated requests include a valid request_id ([#80](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/80)) ([5f9a966](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/5f9a966bb523bf61daefcff209199bc774fa5ed6)) + ## [0.5.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.0...rust-mcp-sdk-v0.5.1) (2025-08-12) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 5f28fa3..97a3e8b 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.5.1" +version = "0.5.2" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-core/Cargo.toml index a38a0b9..6b053b4 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-core" -version = "0.1.16" +version = "0.1.17" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server/Cargo.toml index 7fc7d0f..19dad29 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server" -version = "0.1.25" +version = "0.1.26" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-core-streamable-http/Cargo.toml index 84dfd70..a9883cb 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-core-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-core-streamable-http" -version = "0.1.16" +version = "0.1.17" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index 6776b0c..a5e975e 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.25" +version = "0.1.26" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-core-sse/Cargo.toml index 3cbd9df..d852695 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-core-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core-sse" -version = "0.1.16" +version = "0.1.17" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-core/Cargo.toml index d0288f9..db3282b 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core" -version = "0.1.25" +version = "0.1.26" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 60dd69c..5bad697 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.16" +version = "0.1.17" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client/Cargo.toml index cdfa228..c21c893 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client" -version = "0.1.25" +version = "0.1.26" edition = "2021" publish = false license = "MIT" From 308b1dbd1744ff06046902303d8bcd6c3a92ffbe Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Mon, 18 Aug 2025 21:32:44 -0300 Subject: [PATCH 06/33] fix: handle missing client details and abort keep-alive task on drop (#83) - Added guard (AbortTaskOnDrop) to ensure keep-alive task is aborted when no longer needed - Fixed bug where client_info was mistakenly returning None --- .../src/mcp_runtimes/server_runtime.rs | 21 +++++++------------ crates/rust-mcp-sdk/src/utils.rs | 17 +++++++++++++++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d7b53a1..d787a10 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -10,6 +10,7 @@ use crate::schema::{ }, InitializeRequestParams, InitializeResult, RequestId, RpcError, }; +use crate::utils::AbortTaskOnDrop; use async_trait::async_trait; use futures::future::try_join_all; use futures::{StreamExt, TryFutureExt}; @@ -17,7 +18,7 @@ use futures::{StreamExt, TryFutureExt}; use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::sync::{oneshot, watch}; @@ -41,8 +42,6 @@ pub struct ServerRuntime { handler: Arc, // Information about the server server_details: Arc, - // Details about the connected client - client_details: Arc>>, #[cfg(feature = "hyper-server")] session_id: Option, transport_map: tokio::sync::RwLock>, @@ -123,12 +122,7 @@ impl McpServer for ServerRuntime { /// Returns the client information if available, after successful initialization , otherwise returns None fn client_info(&self) -> Option { - if let Ok(details) = self.client_details.read() { - details.clone() - } else { - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - None - } + self.client_details_rx.borrow().clone() } /// Main runtime loop, processes incoming messages and handles requests @@ -404,6 +398,11 @@ impl ServerRuntime { .await? .abort_handle(); + // ensure keep_alive task will be aborted + let _abort_guard = AbortTaskOnDrop { + handle: abort_alive_task, + }; + // in case there is a payload, we consume it by transport to get processed if let Some(payload) = payload { transport.consume_string_payload(&payload).await?; @@ -439,13 +438,11 @@ impl ServerRuntime { } // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ - abort_alive_task.abort(); return Ok(()); } } _ = &mut disconnect_rx => { self.remove_transport(stream_id).await?; - abort_alive_task.abort(); // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); @@ -469,7 +466,6 @@ impl ServerRuntime { watch::channel::>(None); Self { server_details, - client_details: Arc::new(RwLock::new(None)), handler, session_id: Some(session_id), transport_map: tokio::sync::RwLock::new(HashMap::new()), @@ -495,7 +491,6 @@ impl ServerRuntime { watch::channel::>(None); Self { server_details: Arc::new(server_details), - client_details: Arc::new(RwLock::new(None)), handler, #[cfg(feature = "hyper-server")] session_id: None, diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index de92a06..e98a1ed 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -4,6 +4,23 @@ use crate::error::{McpSdkError, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; +/// A guard type that automatically aborts a Tokio task when dropped. +/// +/// This ensures that the associated task does not outlive the scope +/// of this struct, preventing runaway or leaked background tasks. +/// +pub struct AbortTaskOnDrop { + /// The handle used to abort the spawned Tokio task. + pub handle: tokio::task::AbortHandle, +} + +impl Drop for AbortTaskOnDrop { + fn drop(&mut self) { + // Automatically abort the associated task when this guard is dropped. + self.handle.abort(); + } +} + /// Formats an assertion error message for unsupported capabilities. /// /// Constructs a string describing that a specific entity (e.g., server or client) lacks From 080b3a5360ad7459d933a5f9654b5ac0ac82a59d Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Mon, 18 Aug 2025 21:38:38 -0300 Subject: [PATCH 07/33] chore: release main (#84) --- .release-manifest.json | 18 +++++++++--------- Cargo.lock | 18 +++++++++--------- crates/rust-mcp-sdk/CHANGELOG.md | 7 +++++++ crates/rust-mcp-sdk/Cargo.toml | 2 +- .../hello-world-mcp-server-core/Cargo.toml | 2 +- examples/hello-world-mcp-server/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- examples/simple-mcp-client-core-sse/Cargo.toml | 2 +- examples/simple-mcp-client-core/Cargo.toml | 2 +- examples/simple-mcp-client-sse/Cargo.toml | 2 +- examples/simple-mcp-client/Cargo.toml | 2 +- 12 files changed, 34 insertions(+), 27 deletions(-) diff --git a/.release-manifest.json b/.release-manifest.json index 36d8135..a7e7c0e 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,13 @@ { - "crates/rust-mcp-sdk": "0.5.2", + "crates/rust-mcp-sdk": "0.5.3", "crates/rust-mcp-macros": "0.5.1", "crates/rust-mcp-transport": "0.4.1", - "examples/hello-world-mcp-server": "0.1.26", - "examples/hello-world-mcp-server-core": "0.1.17", - "examples/simple-mcp-client": "0.1.26", - "examples/simple-mcp-client-core": "0.1.26", - "examples/hello-world-server-core-streamable-http": "0.1.17", - "examples/hello-world-server-streamable-http": "0.1.26", - "examples/simple-mcp-client-core-sse": "0.1.17", - "examples/simple-mcp-client-sse": "0.1.17" + "examples/hello-world-mcp-server": "0.1.27", + "examples/hello-world-mcp-server-core": "0.1.18", + "examples/simple-mcp-client": "0.1.27", + "examples/simple-mcp-client-core": "0.1.27", + "examples/hello-world-server-core-streamable-http": "0.1.18", + "examples/hello-world-server-streamable-http": "0.1.27", + "examples/simple-mcp-client-core-sse": "0.1.18", + "examples/simple-mcp-client-sse": "0.1.18" } diff --git a/Cargo.lock b/Cargo.lock index 23688b6..d51b2d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -688,7 +688,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "hello-world-mcp-server" -version = "0.1.26" +version = "0.1.27" dependencies = [ "async-trait", "futures", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "hello-world-mcp-server-core" -version = "0.1.17" +version = "0.1.18" dependencies = [ "async-trait", "futures", @@ -714,7 +714,7 @@ dependencies = [ [[package]] name = "hello-world-server-core-streamable-http" -version = "0.1.17" +version = "0.1.18" dependencies = [ "async-trait", "futures", @@ -728,7 +728,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http" -version = "0.1.26" +version = "0.1.27" dependencies = [ "async-trait", "futures", @@ -1699,7 +1699,7 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.5.2" +version = "0.5.3" dependencies = [ "async-trait", "axum", @@ -1924,7 +1924,7 @@ dependencies = [ [[package]] name = "simple-mcp-client" -version = "0.1.26" +version = "0.1.27" dependencies = [ "async-trait", "colored", @@ -1938,7 +1938,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core" -version = "0.1.26" +version = "0.1.27" dependencies = [ "async-trait", "colored", @@ -1952,7 +1952,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core-sse" -version = "0.1.17" +version = "0.1.18" dependencies = [ "async-trait", "colored", @@ -1968,7 +1968,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse" -version = "0.1.17" +version = "0.1.18" dependencies = [ "async-trait", "colored", diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index 5588727..720c438 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.5.3](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.2...rust-mcp-sdk-v0.5.3) (2025-08-19) + + +### πŸ› Bug Fixes + +* Handle missing client details and abort keep-alive task on drop ([#83](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/83)) ([308b1db](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/308b1dbd1744ff06046902303d8bcd6c3a92ffbe)) + ## [0.5.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.1...rust-mcp-sdk-v0.5.2) (2025-08-16) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 97a3e8b..6e05365 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.5.2" +version = "0.5.3" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-core/Cargo.toml index 6b053b4..b1256a5 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-core" -version = "0.1.17" +version = "0.1.18" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server/Cargo.toml index 19dad29..0f1b5d1 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server" -version = "0.1.26" +version = "0.1.27" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-core-streamable-http/Cargo.toml index a9883cb..afc9c29 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-core-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-core-streamable-http" -version = "0.1.17" +version = "0.1.18" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index a5e975e..3abc10d 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.26" +version = "0.1.27" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-core-sse/Cargo.toml index d852695..d66a7cd 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-core-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core-sse" -version = "0.1.17" +version = "0.1.18" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-core/Cargo.toml index db3282b..9a9c439 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core" -version = "0.1.26" +version = "0.1.27" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 5bad697..3b60bc9 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.17" +version = "0.1.18" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client/Cargo.toml index c21c893..39c2bc5 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client" -version = "0.1.26" +version = "0.1.27" edition = "2021" publish = false license = "MIT" From 287b7138d883563f4d2491e4c64abed9804757bd Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 18:54:19 -0300 Subject: [PATCH 08/33] refactor: request-id generation and messaging functions --- .../src/mcp_runtimes/client_runtime.rs | 40 ++++++- .../src/mcp_runtimes/server_runtime.rs | 55 +++------ .../rust-mcp-sdk/src/mcp_traits/mcp_client.rs | 43 ++----- .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 10 +- crates/rust-mcp-sdk/tests/common/common.rs | 13 ++- .../tests/test_streamable_http.rs | 107 ++++++++++-------- crates/rust-mcp-transport/src/lib.rs | 2 + crates/rust-mcp-transport/src/mcp_stream.rs | 14 +-- .../src/message_dispatcher.rs | 40 ------- .../rust-mcp-transport/src/request_id_gen.rs | 99 ++++++++++++++++ 10 files changed, 233 insertions(+), 190 deletions(-) create mode 100644 crates/rust-mcp-transport/src/request_id_gen.rs diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 8d113c3..1bd1809 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -7,14 +7,19 @@ use crate::schema::{ ServerMessages, }, InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, - RpcError, ServerResult, + RequestId, RpcError, ServerResult, }; use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; -use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::sync::{Arc, RwLock}; +use rust_mcp_transport::{ + IoStream, McpDispatch, MessageDispatcher, RequestIdGen, RequestIdGenNumeric, Transport, +}; +use std::{ + sync::{Arc, RwLock}, + time::Duration, +}; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::sync::Mutex; @@ -41,6 +46,7 @@ pub struct ClientRuntime { // Details about the connected server server_details: Arc>>, handlers: Mutex>>>, + request_id_gen: Box, } impl ClientRuntime { @@ -61,6 +67,7 @@ impl ClientRuntime { client_details, server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -284,6 +291,33 @@ impl McpClient for ClientRuntime { } } + async fn send( + &self, + message: MessageFromClient, + request_id: Option, + timeout: Option, + ) -> SdkResult> { + let sender = self.sender(); + let sender = sender.read().await; + let sender = sender + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + + let response = sender + .send_message(ClientMessages::Single(mcp_message), timeout) + .await? + .map(|res| res.as_single()) + .transpose()?; + + Ok(response) + } + async fn is_shut_down(&self) -> bool { self.transport.is_shut_down().await } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d787a10..a118685 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -16,7 +16,7 @@ use futures::future::try_join_all; use futures::{StreamExt, TryFutureExt}; #[cfg(feature = "hyper-server")] use rust_mcp_transport::SessionId; -use rust_mcp_transport::{IoStream, TransportDispatcher}; +use rust_mcp_transport::{IoStream, RequestIdGen, RequestIdGenNumeric, TransportDispatcher}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -45,6 +45,7 @@ pub struct ServerRuntime { #[cfg(feature = "hyper-server")] session_id: Option, transport_map: tokio::sync::RwLock>, + request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, } @@ -79,7 +80,7 @@ impl McpServer for ServerRuntime { message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult> { + ) -> SdkResult> { let transport_map = self.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -87,14 +88,18 @@ impl McpServer for ServerRuntime { )?; let outgoing_request_id = self - .request_id_for_message(transport, &message, request_id) - .await; + .request_id_gen + .request_id_for_message(&message, request_id); let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?; - transport + + let response = transport .send_message(ServerMessages::Single(mcp_message), request_timeout) - .map_err(|err| err.into()) - .await + .await? + .map(|res| res.as_single()) + .transpose()?; + + Ok(response) } async fn send_batch( @@ -211,40 +216,6 @@ impl ServerRuntime { Ok(()) } - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - pub(crate) async fn request_id_for_message( - &self, - transport: &Arc< - dyn TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, - >, - >, - message: &MessageFromServer, - request_id: Option, - ) -> Option { - let message_sender = transport.message_sender(); - let guard = message_sender.read().await; - if let Some(dispatcher) = guard.as_ref() { - dispatcher.request_id_for_message(message, request_id) - } else { - None - } - } - pub(crate) async fn handle_message( &self, message: ClientMessage, @@ -471,6 +442,7 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(HashMap::new()), client_details_tx, client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -497,6 +469,7 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(map), client_details_tx, client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 8e72c26..2df9dd3 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -10,7 +10,7 @@ use crate::schema::{ InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams, ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest, ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams, - LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, + LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId, RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, @@ -175,27 +175,15 @@ pub trait McpClient: Sync + Send { request: RequestFromClient, timeout: Option, ) -> SdkResult { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let request_id = sender.next_request_id(); - - let mcp_message = - ClientMessage::from_message(MessageFromClient::from(request), Some(request_id))?; - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) + let response = self + .send(MessageFromClient::RequestFromClient(request), None, timeout) .await?; let server_message = response.ok_or_else(|| { RpcError::internal_error() - .with_message("An empty response was received from the server.".to_string()) + .with_message("An empty response was received from the client.".to_string()) })?; - let server_message = server_message.as_single()?; - if server_message.is_error() { return Err(server_message.as_error()?.error.into()); } @@ -205,27 +193,10 @@ pub trait McpClient: Sync + Send { async fn send( &self, - message: ClientMessage, + message: MessageFromClient, + request_id: Option, timeout: Option, - ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Single(message), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_single()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + ) -> SdkResult>; async fn send_batch( &self, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index a1d501d..220ea4d 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -2,8 +2,8 @@ use std::time::Duration; use crate::schema::{ schema_utils::{ - ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer, - RequestFromServer, ResultFromClient, ServerMessage, + ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, + ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult, @@ -44,7 +44,7 @@ pub trait McpServer: Sync + Send { message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult>; + ) -> SdkResult>; async fn send_batch( &self, @@ -84,13 +84,11 @@ pub trait McpServer: Sync + Send { .send(MessageFromServer::RequestFromServer(request), None, timeout) .await?; - let client_messages = response.ok_or_else(|| { + let client_message = response.ok_or_else(|| { RpcError::internal_error() .with_message("An empty response was received from the client.".to_string()) })?; - let client_message = client_messages.as_single()?; - if client_message.is_error() { return Err(client_message.as_error()?.error.into()); } diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 57a3ea8..564db0d 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -128,8 +128,10 @@ use futures::stream::Stream; // stream: &mut impl Stream>, pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), -) -> Option { + event_count: usize, +) -> Option> { let mut buffer = String::new(); + let mut events = vec![]; while let Some(item) = stream.next().await { match item { @@ -158,7 +160,10 @@ pub async fn read_sse_event_from_stream( // Return if data was found if let Some(data) = data { - return Some(data); + events.push(data); + if events.len().eq(&event_count) { + return Some(events); + } } } } @@ -171,9 +176,9 @@ pub async fn read_sse_event_from_stream( None } -pub async fn read_sse_event(response: Response) -> Option { +pub async fn read_sse_event(response: Response, event_count: usize) -> Option> { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream).await + read_sse_event_from_stream(&mut stream, event_count).await } pub fn test_client_info() -> InitializeRequestParams { diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http.rs index 5eb5e47..23ca27f 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http.rs @@ -169,8 +169,8 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -220,8 +220,8 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -345,8 +345,8 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .await .unwrap(); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0]).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -365,7 +365,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { server.hyper_runtime.await_server().await.unwrap() } -// should establish standalone SSE stream and receive server-initiated messages +// should establish standalone SSE stream and receive server-initiated requests #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { let (server, session_id) = initialize_server(None).await.unwrap(); @@ -394,48 +394,59 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { ); let hyper_server = Arc::new(server.hyper_runtime); - let hyper_server_clone = hyper_server.clone(); - let session_id_clone = session_id.to_string(); - - tokio::spawn(async move { - // Send a server-initiated notification that should appear on SSE stream with a valid request_id - hyper_server_clone - .list_roots(&session_id_clone, None) - .await - .unwrap(); - }); - - tokio::time::sleep(Duration::from_millis(2250)).await; - - let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( - RequestId::Integer(0), - ListRootsResult { - meta: None, - roots: vec![], - } - .into(), - ); - send_post_request( - &server.streamable_url, - &serde_json::to_string(&json_rpc_message).unwrap(), - Some(&session_id), - None, - ) - .await - .expect("Request failed"); + // Send two server-initiated request that should appear on SSE stream with a valid request_id + for _ in 0..2 { + let hyper_server_clone = hyper_server.clone(); + let session_id_clone = session_id.to_string(); + tokio::spawn(async move { + hyper_server_clone + .list_roots(&session_id_clone, None) + .await + .unwrap(); + }); + } + + for i in 0..2 { + // send responses back to the server for two server initiated requests + let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( + RequestId::Integer(i), + ListRootsResult { + meta: None, + roots: vec![], + } + .into(), + ); + send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + } - let event = read_sse_event(response).await.unwrap(); + // read two events from the sse stream + let events = read_sse_event(response, 2).await.unwrap(); - let message: ServerJsonrpcRequest = serde_json::from_str(&event).unwrap(); + let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0]).unwrap(); - println!(">>> message {:?} ", message); + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request + else { + panic!("invalid message received!"); + }; - let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message.request + let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1]).unwrap(); + + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { panic!("invalid message received!"); }; + // ensure request_ids are unique + assert!(message2.id != message1.id); + hyper_server.graceful_shutdown(ONE_MILLISECOND); } @@ -461,7 +472,7 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let mut stream = response.bytes_stream(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -490,7 +501,7 @@ async fn should_not_close_get_sse_stream() { .await .unwrap(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -702,8 +713,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_1.status(), StatusCode::OK); assert_eq!(response_2.status(), StatusCode::OK); - let event = read_sse_event(response_2).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_2, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -718,8 +729,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() "Hello, Ali!" ); - let event = read_sse_event(response_1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_1, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -1069,8 +1080,8 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { "text/event-stream" ); - let event = read_sse_event(response).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerMessages = serde_json::from_str(&events[0]).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 1634922..26aa2b8 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -6,6 +6,7 @@ mod client_sse; pub mod error; mod mcp_stream; mod message_dispatcher; +mod request_id_gen; mod schema; #[cfg(feature = "sse")] mod sse; @@ -16,6 +17,7 @@ mod utils; #[cfg(feature = "sse")] pub use client_sse::*; pub use message_dispatcher::*; +pub use request_id_gen::*; #[cfg(feature = "sse")] pub use sse::*; pub use stdio::*; diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 2d2a377..08bdc21 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -5,12 +5,7 @@ use crate::{ utils::CancellationToken, IoStream, }; -use std::{ - collections::HashMap, - pin::Pin, - sync::{atomic::AtomicI64, Arc}, - time::Duration, -}; +use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; use tokio::task::JoinHandle; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -57,12 +52,7 @@ impl MCPStream { // rpc message stream that receives incoming messages - let sender = MessageDispatcher::new( - pending_requests, - writable, - Arc::new(AtomicI64::new(0)), - request_timeout, - ); + let sender = MessageDispatcher::new(pending_requests, writable, request_timeout); (stream, sender, error_io) } diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 22d0b58..ea1eb04 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -10,7 +10,6 @@ use futures::future::join_all; use std::collections::HashMap; use std::pin::Pin; -use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; @@ -31,7 +30,6 @@ use crate::McpDispatch; pub struct MessageDispatcher { pending_requests: Arc>>>, writable_std: Mutex>>, - message_id_counter: Arc, request_timeout: Duration, } @@ -49,53 +47,15 @@ impl MessageDispatcher { pub fn new( pending_requests: Arc>>>, writable_std: Mutex>>, - message_id_counter: Arc, request_timeout: Duration, ) -> Self { Self { pending_requests, writable_std, - message_id_counter, request_timeout, } } - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - pub fn request_id_for_message( - &self, - message: &impl McpMessage, - request_id: Option, - ) -> Option { - // we need to produce next request_id for requests - if message.is_request() { - // request_id should be None for requests - assert!(request_id.is_none()); - Some(self.next_request_id()) - } else if !message.is_notification() { - // `request_id` must not be `None` for errors, notifications and responses - assert!(request_id.is_some()); - request_id - } else { - None - } - } - pub fn next_request_id(&self) -> RequestId { - RequestId::Integer( - self.message_id_counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed), - ) - } - async fn store_pending_request( &self, request_id: RequestId, diff --git a/crates/rust-mcp-transport/src/request_id_gen.rs b/crates/rust-mcp-transport/src/request_id_gen.rs new file mode 100644 index 0000000..598ab70 --- /dev/null +++ b/crates/rust-mcp-transport/src/request_id_gen.rs @@ -0,0 +1,99 @@ +use std::sync::atomic::AtomicI64; + +use crate::schema::{schema_utils::McpMessage, RequestId}; +use async_trait::async_trait; + +/// A trait for generating and managing request IDs in a thread-safe manner. +/// +/// Implementors provide functionality to generate unique request IDs, retrieve the last +/// generated ID, and reset the ID counter. +#[async_trait] +pub trait RequestIdGen: Send + Sync { + fn next_request_id(&self) -> RequestId; + fn last_request_id(&self) -> Option; + fn reset_to(&self, id: u64); + + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + fn request_id_for_message( + &self, + message: &dyn McpMessage, + request_id: Option, + ) -> Option { + // we need to produce next request_id for requests + if message.is_request() { + // request_id should be None for requests + assert!(request_id.is_none()); + Some(self.next_request_id()) + } else if !message.is_notification() { + // `request_id` must not be `None` for errors, notifications and responses + assert!(request_id.is_some()); + request_id + } else { + None + } + } +} + +pub struct RequestIdGenNumeric { + message_id_counter: AtomicI64, + last_message_id: AtomicI64, +} + +impl RequestIdGenNumeric { + pub fn new(initial_id: Option) -> Self { + Self { + message_id_counter: AtomicI64::new(initial_id.unwrap_or(0) as i64), + last_message_id: AtomicI64::new(-1), + } + } +} + +impl RequestIdGen for RequestIdGenNumeric { + /// Generates the next unique request ID as an integer. + /// + /// Increments the internal counter atomically and updates the last generated ID. + /// Uses `Relaxed` ordering for performance, as the counter only needs to ensure unique IDs. + fn next_request_id(&self) -> RequestId { + let id = self + .message_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // Store the new ID as the last generated ID + self.last_message_id + .store(id, std::sync::atomic::Ordering::Relaxed); + RequestId::Integer(id) + } + + /// Returns the last generated request ID, if any. + /// + /// Returns `None` if no ID has been generated (indicated by a sentinel value of -1). + /// Uses `Relaxed` ordering since the read operation doesn’t require synchronization + /// with other memory operations beyond atomicity. + fn last_request_id(&self) -> Option { + let last_id = self + .last_message_id + .load(std::sync::atomic::Ordering::Relaxed); + if last_id == -1 { + None + } else { + Some(RequestId::Integer(last_id)) + } + } + + /// Resets the internal counter to the specified ID. + /// + /// The provided `id` (u64) is converted to i64 and stored atomically. + fn reset_to(&self, id: u64) { + self.message_id_counter + .store(id as i64, std::sync::atomic::Ordering::Relaxed); + } +} From 91be09f5ab2c604848f903f26031382a113f0c45 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 19:14:22 -0300 Subject: [PATCH 09/33] refactor: updated the handler traits and fixed error handling for mcp error messages --- .../src/mcp_handlers/mcp_client_handler.rs | 2 +- .../src/mcp_handlers/mcp_client_handler_core.rs | 2 +- .../src/mcp_handlers/mcp_server_handler.rs | 2 +- .../src/mcp_handlers/mcp_server_handler_core.rs | 2 +- .../src/mcp_runtimes/client_runtime.rs | 16 ++++++++++++++-- .../client_runtime/mcp_client_runtime.rs | 2 +- .../client_runtime/mcp_client_runtime_core.rs | 2 +- .../src/mcp_runtimes/server_runtime.rs | 16 ++++++++++++++-- .../server_runtime/mcp_server_runtime.rs | 2 +- .../server_runtime/mcp_server_runtime_core.rs | 2 +- .../rust-mcp-sdk/src/mcp_traits/mcp_handler.rs | 14 ++++++++++---- .../hello-world-mcp-server-core/src/handler.rs | 2 +- .../src/handler.rs | 2 +- .../simple-mcp-client-core-sse/src/handler.rs | 2 +- examples/simple-mcp-client-core/src/handler.rs | 2 +- 15 files changed, 50 insertions(+), 20 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs index f8ee1a0..c6fb208 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs @@ -148,7 +148,7 @@ pub trait ClientHandler: Send + Sync + 'static { //********************// async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs index 3bbe5c9..a0afdf1 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs @@ -38,7 +38,7 @@ pub trait ClientHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP server. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError>; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index bf3fe17..89aebf5 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -319,7 +319,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index fffe2fc..e7b0e6d 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -45,7 +45,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP client. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError>; async fn on_server_started(&self, runtime: &dyn McpServer) { diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 1bd1809..f0def82 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -130,7 +130,19 @@ impl ClientRuntime { None } ServerMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ServerMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request: {:?}", + &jsonrpc_error.id + ); + } None } ServerMessage::Response(response) => { @@ -140,7 +152,7 @@ impl ClientRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 9ccd4d9..7925f07 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -113,7 +113,7 @@ impl McpClientHandler for ClientInternalHandler> { /// Handles errors received from the server by passing the request to self.handler async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 3bdc318..8cb8cff 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -83,7 +83,7 @@ impl McpClientHandler for ClientCoreInternalHandler> async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index a118685..0672fd3 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -261,7 +261,19 @@ impl ServerRuntime { None } ClientMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ClientMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request {:?}", + &jsonrpc_error.id + ); + } None } // The response is the result of a request, it is processed at the transport level. @@ -272,7 +284,7 @@ impl ServerRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index 26f37e1..ea19e19 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -177,7 +177,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index 154b4bc..e0e7108 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -87,7 +87,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> } async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index c86a623..2974bfc 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -24,8 +24,11 @@ pub trait McpServerHandler: Send + Sync { client_jsonrpc_request: RequestFromClient, runtime: &dyn McpServer, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpServer) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: &dyn McpServer, + ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, @@ -41,8 +44,11 @@ pub trait McpClientHandler: Send + Sync { server_jsonrpc_request: RequestFromServer, runtime: &dyn McpClient, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpClient) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: &dyn McpClient, + ) -> SdkResult<()>; async fn handle_notification( &self, server_jsonrpc_notification: NotificationFromServer, diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-core/src/handler.rs index fcde15e..f0bdefe 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-core/src/handler.rs @@ -98,7 +98,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, + error: &RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-core-streamable-http/src/handler.rs index 53f884c..1c69e8c 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-core-streamable-http/src/handler.rs @@ -103,7 +103,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, + error: &RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-core-sse/src/handler.rs index a1a95e4..bd5e4fe 100644 --- a/examples/simple-mcp-client-core-sse/src/handler.rs +++ b/examples/simple-mcp-client-core-sse/src/handler.rs @@ -50,7 +50,7 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-core/src/handler.rs index a1a95e4..bd5e4fe 100644 --- a/examples/simple-mcp-client-core/src/handler.rs +++ b/examples/simple-mcp-client-core/src/handler.rs @@ -50,7 +50,7 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) From 5208284795b9f338d66f11e1b97bb47b5ba30ba9 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 19:17:51 -0300 Subject: [PATCH 10/33] refactor: move RequestIdGen trait to the sdk crate --- .../src/mcp_runtimes/client_runtime.rs | 19 ++-- .../src/mcp_runtimes/server_runtime.rs | 3 +- crates/rust-mcp-sdk/src/mcp_traits.rs | 3 + crates/rust-mcp-transport/src/lib.rs | 2 - .../rust-mcp-transport/src/request_id_gen.rs | 99 ------------------- 5 files changed, 15 insertions(+), 111 deletions(-) delete mode 100644 crates/rust-mcp-transport/src/request_id_gen.rs diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index f0def82..7ee0815 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,21 +1,22 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, +use crate::{ + mcp_traits::{RequestIdGen, RequestIdGenNumeric}, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, + ServerMessages, + }, + InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, + RequestId, RpcError, ServerResult, }, - InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, - RequestId, RpcError, ServerResult, }; use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; -use rust_mcp_transport::{ - IoStream, McpDispatch, MessageDispatcher, RequestIdGen, RequestIdGenNumeric, Transport, -}; +use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; use std::{ sync::{Arc, RwLock}, time::Duration, diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 0672fd3..49b5c3c 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -3,6 +3,7 @@ pub mod mcp_server_runtime_core; use crate::error::SdkResult; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; +use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric}; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, @@ -16,7 +17,7 @@ use futures::future::try_join_all; use futures::{StreamExt, TryFutureExt}; #[cfg(feature = "hyper-server")] use rust_mcp_transport::SessionId; -use rust_mcp_transport::{IoStream, RequestIdGen, RequestIdGenNumeric, TransportDispatcher}; +use rust_mcp_transport::{IoStream, TransportDispatcher}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 511731c..2b155fa 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -3,3 +3,6 @@ pub mod mcp_client; pub mod mcp_handler; #[cfg(feature = "server")] pub mod mcp_server; +mod request_id_gen; + +pub use request_id_gen::*; diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 26aa2b8..1634922 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -6,7 +6,6 @@ mod client_sse; pub mod error; mod mcp_stream; mod message_dispatcher; -mod request_id_gen; mod schema; #[cfg(feature = "sse")] mod sse; @@ -17,7 +16,6 @@ mod utils; #[cfg(feature = "sse")] pub use client_sse::*; pub use message_dispatcher::*; -pub use request_id_gen::*; #[cfg(feature = "sse")] pub use sse::*; pub use stdio::*; diff --git a/crates/rust-mcp-transport/src/request_id_gen.rs b/crates/rust-mcp-transport/src/request_id_gen.rs deleted file mode 100644 index 598ab70..0000000 --- a/crates/rust-mcp-transport/src/request_id_gen.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::sync::atomic::AtomicI64; - -use crate::schema::{schema_utils::McpMessage, RequestId}; -use async_trait::async_trait; - -/// A trait for generating and managing request IDs in a thread-safe manner. -/// -/// Implementors provide functionality to generate unique request IDs, retrieve the last -/// generated ID, and reset the ID counter. -#[async_trait] -pub trait RequestIdGen: Send + Sync { - fn next_request_id(&self) -> RequestId; - fn last_request_id(&self) -> Option; - fn reset_to(&self, id: u64); - - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - fn request_id_for_message( - &self, - message: &dyn McpMessage, - request_id: Option, - ) -> Option { - // we need to produce next request_id for requests - if message.is_request() { - // request_id should be None for requests - assert!(request_id.is_none()); - Some(self.next_request_id()) - } else if !message.is_notification() { - // `request_id` must not be `None` for errors, notifications and responses - assert!(request_id.is_some()); - request_id - } else { - None - } - } -} - -pub struct RequestIdGenNumeric { - message_id_counter: AtomicI64, - last_message_id: AtomicI64, -} - -impl RequestIdGenNumeric { - pub fn new(initial_id: Option) -> Self { - Self { - message_id_counter: AtomicI64::new(initial_id.unwrap_or(0) as i64), - last_message_id: AtomicI64::new(-1), - } - } -} - -impl RequestIdGen for RequestIdGenNumeric { - /// Generates the next unique request ID as an integer. - /// - /// Increments the internal counter atomically and updates the last generated ID. - /// Uses `Relaxed` ordering for performance, as the counter only needs to ensure unique IDs. - fn next_request_id(&self) -> RequestId { - let id = self - .message_id_counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - // Store the new ID as the last generated ID - self.last_message_id - .store(id, std::sync::atomic::Ordering::Relaxed); - RequestId::Integer(id) - } - - /// Returns the last generated request ID, if any. - /// - /// Returns `None` if no ID has been generated (indicated by a sentinel value of -1). - /// Uses `Relaxed` ordering since the read operation doesn’t require synchronization - /// with other memory operations beyond atomicity. - fn last_request_id(&self) -> Option { - let last_id = self - .last_message_id - .load(std::sync::atomic::Ordering::Relaxed); - if last_id == -1 { - None - } else { - Some(RequestId::Integer(last_id)) - } - } - - /// Resets the internal counter to the specified ID. - /// - /// The provided `id` (u64) is converted to i64 and stored atomically. - fn reset_to(&self, id: u64) { - self.message_id_counter - .store(id as i64, std::sync::atomic::Ordering::Relaxed); - } -} From acdd49988150d100788dcdda11310645428a13f0 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 19:18:07 -0300 Subject: [PATCH 11/33] chore: request_id_gen --- .../src/mcp_traits/request_id_gen.rs | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs diff --git a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs new file mode 100644 index 0000000..598ab70 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs @@ -0,0 +1,99 @@ +use std::sync::atomic::AtomicI64; + +use crate::schema::{schema_utils::McpMessage, RequestId}; +use async_trait::async_trait; + +/// A trait for generating and managing request IDs in a thread-safe manner. +/// +/// Implementors provide functionality to generate unique request IDs, retrieve the last +/// generated ID, and reset the ID counter. +#[async_trait] +pub trait RequestIdGen: Send + Sync { + fn next_request_id(&self) -> RequestId; + fn last_request_id(&self) -> Option; + fn reset_to(&self, id: u64); + + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + fn request_id_for_message( + &self, + message: &dyn McpMessage, + request_id: Option, + ) -> Option { + // we need to produce next request_id for requests + if message.is_request() { + // request_id should be None for requests + assert!(request_id.is_none()); + Some(self.next_request_id()) + } else if !message.is_notification() { + // `request_id` must not be `None` for errors, notifications and responses + assert!(request_id.is_some()); + request_id + } else { + None + } + } +} + +pub struct RequestIdGenNumeric { + message_id_counter: AtomicI64, + last_message_id: AtomicI64, +} + +impl RequestIdGenNumeric { + pub fn new(initial_id: Option) -> Self { + Self { + message_id_counter: AtomicI64::new(initial_id.unwrap_or(0) as i64), + last_message_id: AtomicI64::new(-1), + } + } +} + +impl RequestIdGen for RequestIdGenNumeric { + /// Generates the next unique request ID as an integer. + /// + /// Increments the internal counter atomically and updates the last generated ID. + /// Uses `Relaxed` ordering for performance, as the counter only needs to ensure unique IDs. + fn next_request_id(&self) -> RequestId { + let id = self + .message_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // Store the new ID as the last generated ID + self.last_message_id + .store(id, std::sync::atomic::Ordering::Relaxed); + RequestId::Integer(id) + } + + /// Returns the last generated request ID, if any. + /// + /// Returns `None` if no ID has been generated (indicated by a sentinel value of -1). + /// Uses `Relaxed` ordering since the read operation doesn’t require synchronization + /// with other memory operations beyond atomicity. + fn last_request_id(&self) -> Option { + let last_id = self + .last_message_id + .load(std::sync::atomic::Ordering::Relaxed); + if last_id == -1 { + None + } else { + Some(RequestId::Integer(last_id)) + } + } + + /// Resets the internal counter to the specified ID. + /// + /// The provided `id` (u64) is converted to i64 and stored atomically. + fn reset_to(&self, id: u64) { + self.message_id_counter + .store(id as i64, std::sync::atomic::Ordering::Relaxed); + } +} From 0d0b1bafd6acfbd945ab3299329b258e9f3ce788 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 19:23:22 -0300 Subject: [PATCH 12/33] chore: remove deprecated methods --- crates/rust-mcp-sdk/src/error.rs | 3 --- .../rust-mcp-sdk/src/mcp_macros/tool_box.rs | 9 ------- .../rust-mcp-sdk/src/mcp_traits/mcp_client.rs | 25 ------------------- .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 10 -------- .../src/mcp_traits/request_id_gen.rs | 2 ++ 5 files changed, 2 insertions(+), 47 deletions(-) diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 2feab67..3de8d98 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -41,6 +41,3 @@ impl McpSdkError { None } } - -#[deprecated(since = "0.2.0", note = "Use `McpSdkError` instead.")] -pub type MCPSdkError = McpSdkError; diff --git a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs index 3bd2735..a5b75d5 100644 --- a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs +++ b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs @@ -57,15 +57,6 @@ macro_rules! tool_box { )* ] } - - #[deprecated(since = "0.2.0", note = "Use `tools()` instead.")] - pub fn get_tools() -> Vec { - vec![ - $( - $tool::tool(), - )* - ] - } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 2df9dd3..1883581 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -35,16 +35,6 @@ pub trait McpClient: Sync + Send { fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> &InitializeRequestParams { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> Option { - self.server_info() - } - /// Checks whether the server has been initialized with client fn is_initialized(&self) -> bool { self.server_info().is_some() @@ -57,23 +47,12 @@ pub trait McpClient: Sync + Send { .map(|server_details| server_details.server_info) } - #[deprecated(since = "0.2.0", note = "Use `server_version()` instead.")] - fn get_server_version(&self) -> Option { - self.server_info() - .map(|server_details| server_details.server_info) - } - /// Returns the server's capabilities. /// After initialization has completed, this will be populated with the server's reported capabilities. fn server_capabilities(&self) -> Option { self.server_info().map(|item| item.capabilities) } - #[deprecated(since = "0.2.0", note = "Use `server_capabilities()` instead.")] - fn get_server_capabilities(&self) -> Option { - self.server_info().map(|item| item.capabilities) - } - /// Checks if the server has tools available. /// /// This function retrieves the server information and checks if the @@ -156,10 +135,6 @@ pub trait McpClient: Sync + Send { self.server_info() .map(|server_details| server_details.capabilities.logging.is_some()) } - #[deprecated(since = "0.2.0", note = "Use `instructions()` instead.")] - fn get_instructions(&self) -> Option { - self.server_info()?.instructions - } fn instructions(&self) -> Option { self.server_info()?.instructions diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 220ea4d..0130c33 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -29,16 +29,6 @@ pub trait McpServer: Sync + Send { async fn wait_for_initialization(&self); - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> Option { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> &InitializeResult { - self.server_info() - } - async fn send( &self, message: MessageFromServer, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs index 598ab70..2372ae9 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs @@ -10,7 +10,9 @@ use async_trait::async_trait; #[async_trait] pub trait RequestIdGen: Send + Sync { fn next_request_id(&self) -> RequestId; + #[allow(unused)] fn last_request_id(&self) -> Option; + #[allow(unused)] fn reset_to(&self, id: u64); /// Determines the request ID for an outgoing MCP message. From c468742431c81a4f1db779bf8a22e9e780689cbf Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 19:58:56 -0300 Subject: [PATCH 13/33] chore: update sse client core example --- .../simple-mcp-client-core-sse/src/handler.rs | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-core-sse/src/handler.rs index bd5e4fe..ec2095e 100644 --- a/examples/simple-mcp-client-core-sse/src/handler.rs +++ b/examples/simple-mcp-client-core-sse/src/handler.rs @@ -41,11 +41,25 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_notification( &self, - _notification: NotificationFromServer, + notification: NotificationFromServer, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { - Err(RpcError::internal_error() - .with_message("handle_notification() Not implemented".to_string())) + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data.to_string() + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) } async fn handle_error( From 23f3aad9acf543568a59a93e6b24a83d9e229805 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 20:09:12 -0300 Subject: [PATCH 14/33] chore: update readme --- README.md | 1 + crates/rust-mcp-sdk/README.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index ef5b4ed..1581d1d 100644 --- a/README.md +++ b/README.md @@ -526,6 +526,7 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index ef5b4ed..1581d1d 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -526,6 +526,7 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | From 3889c01a8b421d8d8765168089ea73404d22ad9c Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 19 Aug 2025 20:25:59 -0300 Subject: [PATCH 15/33] chore: update dependencies --- Cargo.lock | 82 ++++++++++--------- .../simple-mcp-client-core-sse/src/handler.rs | 2 +- 2 files changed, 43 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d51b2d6..9d4b91d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -118,7 +118,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "itoa", "matchit", @@ -170,7 +170,7 @@ dependencies = [ "fs-err", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "pin-project-lite", "rustls", @@ -239,9 +239,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +checksum = "6a65b545ab31d687cff52899d4890855fec459eb6afe0da6417b8a18da87aa29" [[package]] name = "bumpalo" @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.32" +version = "1.2.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e" +checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f" dependencies = [ "jobserver", "libc", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "cfg_aliases" @@ -870,13 +870,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", @@ -884,6 +885,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -896,7 +898,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ "http 1.3.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "rustls", "rustls-pki-types", @@ -919,7 +921,7 @@ dependencies = [ "futures-util", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "ipnet", "libc", "percent-encoding", @@ -1385,9 +1387,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.36" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", "syn", @@ -1395,9 +1397,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.97" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -1432,7 +1434,7 @@ dependencies = [ "rustc-hash 2.1.1", "rustls", "socket2 0.5.10", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "web-time", @@ -1453,7 +1455,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.14", + "thiserror 2.0.15", "tinyvec", "tracing", "web-time", @@ -1626,7 +1628,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-rustls", "hyper-util", "js-sys", @@ -1705,14 +1707,14 @@ dependencies = [ "axum", "axum-server", "futures", - "hyper 1.6.0", + "hyper 1.7.0", "reqwest", "rust-mcp-macros", "rust-mcp-schema", "rust-mcp-transport", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tokio-stream", "tracing", @@ -1731,7 +1733,7 @@ dependencies = [ "rust-mcp-schema", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tokio-stream", "tracing", @@ -1855,9 +1857,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.142" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ "itoa", "memchr", @@ -1932,7 +1934,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", ] @@ -1946,7 +1948,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", ] @@ -1960,7 +1962,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "tracing-subscriber", @@ -1976,7 +1978,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "tracing-subscriber", @@ -2028,9 +2030,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.104" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -2068,11 +2070,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.14" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b0949c3a6c842cbde3f1686d6eea5a010516deb7085f79db747562d4102f41e" +checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850" dependencies = [ - "thiserror-impl 2.0.14", + "thiserror-impl 2.0.15", ] [[package]] @@ -2088,9 +2090,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.14" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc5b44b4ab9c2fdd0e0512e6bece8388e214c0749f5862b114cc5b7a25daf227" +checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" dependencies = [ "proc-macro2", "quote", @@ -2149,9 +2151,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ "tinyvec_macros", ] diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-core-sse/src/handler.rs index ec2095e..ab86e9e 100644 --- a/examples/simple-mcp-client-core-sse/src/handler.rs +++ b/examples/simple-mcp-client-core-sse/src/handler.rs @@ -50,7 +50,7 @@ impl ClientHandlerCore for MyClientHandler { { println!( "Notification from server: {}", - logging_message_notification.params.data.to_string() + logging_message_notification.params.data ); } else { println!( From 1631fa1e576fa51e574ef99c7239bdfe3b909d6d Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Tue, 19 Aug 2025 20:40:13 -0300 Subject: [PATCH 16/33] Revert "refactor!: improve request ID generation, remove deprecated methods and adding improvements" --- Cargo.lock | 82 +++++++------- README.md | 1 - crates/rust-mcp-sdk/README.md | 1 - crates/rust-mcp-sdk/src/error.rs | 3 + .../src/mcp_handlers/mcp_client_handler.rs | 2 +- .../mcp_handlers/mcp_client_handler_core.rs | 2 +- .../src/mcp_handlers/mcp_server_handler.rs | 2 +- .../mcp_handlers/mcp_server_handler_core.rs | 2 +- .../rust-mcp-sdk/src/mcp_macros/tool_box.rs | 9 ++ .../src/mcp_runtimes/client_runtime.rs | 65 ++--------- .../client_runtime/mcp_client_runtime.rs | 2 +- .../client_runtime/mcp_client_runtime_core.rs | 2 +- .../src/mcp_runtimes/server_runtime.rs | 70 +++++++----- .../server_runtime/mcp_server_runtime.rs | 2 +- .../server_runtime/mcp_server_runtime_core.rs | 2 +- crates/rust-mcp-sdk/src/mcp_traits.rs | 3 - .../rust-mcp-sdk/src/mcp_traits/mcp_client.rs | 68 +++++++++-- .../src/mcp_traits/mcp_handler.rs | 14 +-- .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 20 +++- .../src/mcp_traits/request_id_gen.rs | 101 ----------------- crates/rust-mcp-sdk/tests/common/common.rs | 13 +-- .../tests/test_streamable_http.rs | 107 ++++++++---------- crates/rust-mcp-transport/src/mcp_stream.rs | 14 ++- .../src/message_dispatcher.rs | 40 +++++++ .../src/handler.rs | 2 +- .../src/handler.rs | 2 +- .../simple-mcp-client-core-sse/src/handler.rs | 22 +--- .../simple-mcp-client-core/src/handler.rs | 2 +- 28 files changed, 303 insertions(+), 352 deletions(-) delete mode 100644 crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs diff --git a/Cargo.lock b/Cargo.lock index 9d4b91d..d51b2d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.89" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", @@ -118,7 +118,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.7.0", + "hyper 1.6.0", "hyper-util", "itoa", "matchit", @@ -170,7 +170,7 @@ dependencies = [ "fs-err", "http 1.3.1", "http-body 1.0.1", - "hyper 1.7.0", + "hyper 1.6.0", "hyper-util", "pin-project-lite", "rustls", @@ -239,9 +239,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.2" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a65b545ab31d687cff52899d4890855fec459eb6afe0da6417b8a18da87aa29" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bumpalo" @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.33" +version = "1.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f" +checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e" dependencies = [ "jobserver", "libc", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.3" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" [[package]] name = "cfg_aliases" @@ -870,14 +870,13 @@ dependencies = [ [[package]] name = "hyper" -version = "1.7.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ - "atomic-waker", "bytes", "futures-channel", - "futures-core", + "futures-util", "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", @@ -885,7 +884,6 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "pin-utils", "smallvec", "tokio", "want", @@ -898,7 +896,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ "http 1.3.1", - "hyper 1.7.0", + "hyper 1.6.0", "hyper-util", "rustls", "rustls-pki-types", @@ -921,7 +919,7 @@ dependencies = [ "futures-util", "http 1.3.1", "http-body 1.0.1", - "hyper 1.7.0", + "hyper 1.6.0", "ipnet", "libc", "percent-encoding", @@ -1387,9 +1385,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.37" +version = "0.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", "syn", @@ -1397,9 +1395,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.101" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" dependencies = [ "unicode-ident", ] @@ -1434,7 +1432,7 @@ dependencies = [ "rustc-hash 2.1.1", "rustls", "socket2 0.5.10", - "thiserror 2.0.15", + "thiserror 2.0.14", "tokio", "tracing", "web-time", @@ -1455,7 +1453,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.15", + "thiserror 2.0.14", "tinyvec", "tracing", "web-time", @@ -1628,7 +1626,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.7.0", + "hyper 1.6.0", "hyper-rustls", "hyper-util", "js-sys", @@ -1707,14 +1705,14 @@ dependencies = [ "axum", "axum-server", "futures", - "hyper 1.7.0", + "hyper 1.6.0", "reqwest", "rust-mcp-macros", "rust-mcp-schema", "rust-mcp-transport", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.14", "tokio", "tokio-stream", "tracing", @@ -1733,7 +1731,7 @@ dependencies = [ "rust-mcp-schema", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.14", "tokio", "tokio-stream", "tracing", @@ -1857,9 +1855,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.143" +version = "1.0.142" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" +checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7" dependencies = [ "itoa", "memchr", @@ -1934,7 +1932,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.14", "tokio", ] @@ -1948,7 +1946,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.14", "tokio", ] @@ -1962,7 +1960,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.14", "tokio", "tracing", "tracing-subscriber", @@ -1978,7 +1976,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.14", "tokio", "tracing", "tracing-subscriber", @@ -2030,9 +2028,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.106" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -2070,11 +2068,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.15" +version = "2.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850" +checksum = "0b0949c3a6c842cbde3f1686d6eea5a010516deb7085f79db747562d4102f41e" dependencies = [ - "thiserror-impl 2.0.15", + "thiserror-impl 2.0.14", ] [[package]] @@ -2090,9 +2088,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.15" +version = "2.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" +checksum = "cc5b44b4ab9c2fdd0e0512e6bece8388e214c0749f5862b114cc5b7a25daf227" dependencies = [ "proc-macro2", "quote", @@ -2151,9 +2149,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" dependencies = [ "tinyvec_macros", ] diff --git a/README.md b/README.md index 1581d1d..ef5b4ed 100644 --- a/README.md +++ b/README.md @@ -526,7 +526,6 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | -| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 1581d1d..ef5b4ed 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -526,7 +526,6 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | -| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 3de8d98..2feab67 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -41,3 +41,6 @@ impl McpSdkError { None } } + +#[deprecated(since = "0.2.0", note = "Use `McpSdkError` instead.")] +pub type MCPSdkError = McpSdkError; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs index c6fb208..f8ee1a0 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs @@ -148,7 +148,7 @@ pub trait ClientHandler: Send + Sync + 'static { //********************// async fn handle_error( &self, - error: &RpcError, + error: RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs index a0afdf1..3bbe5c9 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs @@ -38,7 +38,7 @@ pub trait ClientHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP server. async fn handle_error( &self, - error: &RpcError, + error: RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError>; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 89aebf5..bf3fe17 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -319,7 +319,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_error( &self, - error: &RpcError, + error: RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index e7b0e6d..fffe2fc 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -45,7 +45,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP client. async fn handle_error( &self, - error: &RpcError, + error: RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError>; async fn on_server_started(&self, runtime: &dyn McpServer) { diff --git a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs index a5b75d5..3bd2735 100644 --- a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs +++ b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs @@ -57,6 +57,15 @@ macro_rules! tool_box { )* ] } + + #[deprecated(since = "0.2.0", note = "Use `tools()` instead.")] + pub fn get_tools() -> Vec { + vec![ + $( + $tool::tool(), + )* + ] + } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 7ee0815..8d113c3 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,26 +1,20 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; -use crate::{ - mcp_traits::{RequestIdGen, RequestIdGenNumeric}, - schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, - }, - InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, - RequestId, RpcError, ServerResult, +use crate::schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, + ServerMessages, }, + InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, + RpcError, ServerResult, }; use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::{ - sync::{Arc, RwLock}, - time::Duration, -}; +use std::sync::{Arc, RwLock}; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::sync::Mutex; @@ -47,7 +41,6 @@ pub struct ClientRuntime { // Details about the connected server server_details: Arc>>, handlers: Mutex>>>, - request_id_gen: Box, } impl ClientRuntime { @@ -68,7 +61,6 @@ impl ClientRuntime { client_details, server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), - request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -131,19 +123,7 @@ impl ClientRuntime { None } ServerMessage::Error(jsonrpc_error) => { - self.handler - .handle_error(&jsonrpc_error.error, self) - .await?; - if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { - tx_response - .send(ServerMessage::Error(jsonrpc_error)) - .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; - } else { - tracing::warn!( - "Received an error response with no corresponding request: {:?}", - &jsonrpc_error.id - ); - } + self.handler.handle_error(jsonrpc_error.error, self).await?; None } ServerMessage::Response(response) => { @@ -153,7 +133,7 @@ impl ClientRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received a response with no corresponding request: {:?}", + "Received response or error without a matching request: {:?}", &response.id ); } @@ -304,33 +284,6 @@ impl McpClient for ClientRuntime { } } - async fn send( - &self, - message: MessageFromClient, - request_id: Option, - timeout: Option, - ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let outgoing_request_id = self - .request_id_gen - .request_id_for_message(&message, request_id); - - let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; - - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) - .await? - .map(|res| res.as_single()) - .transpose()?; - - Ok(response) - } - async fn is_shut_down(&self) -> bool { self.transport.is_shut_down().await } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 7925f07..9ccd4d9 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -113,7 +113,7 @@ impl McpClientHandler for ClientInternalHandler> { /// Handles errors received from the server by passing the request to self.handler async fn handle_error( &self, - jsonrpc_error: &RpcError, + jsonrpc_error: RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 8cb8cff..3bdc318 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -83,7 +83,7 @@ impl McpClientHandler for ClientCoreInternalHandler> async fn handle_error( &self, - jsonrpc_error: &RpcError, + jsonrpc_error: RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 49b5c3c..d787a10 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -3,7 +3,6 @@ pub mod mcp_server_runtime_core; use crate::error::SdkResult; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; -use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric}; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, @@ -46,7 +45,6 @@ pub struct ServerRuntime { #[cfg(feature = "hyper-server")] session_id: Option, transport_map: tokio::sync::RwLock>, - request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, } @@ -81,7 +79,7 @@ impl McpServer for ServerRuntime { message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult> { + ) -> SdkResult> { let transport_map = self.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -89,18 +87,14 @@ impl McpServer for ServerRuntime { )?; let outgoing_request_id = self - .request_id_gen - .request_id_for_message(&message, request_id); + .request_id_for_message(transport, &message, request_id) + .await; let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?; - - let response = transport + transport .send_message(ServerMessages::Single(mcp_message), request_timeout) - .await? - .map(|res| res.as_single()) - .transpose()?; - - Ok(response) + .map_err(|err| err.into()) + .await } async fn send_batch( @@ -217,6 +211,40 @@ impl ServerRuntime { Ok(()) } + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + pub(crate) async fn request_id_for_message( + &self, + transport: &Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, + >, + message: &MessageFromServer, + request_id: Option, + ) -> Option { + let message_sender = transport.message_sender(); + let guard = message_sender.read().await; + if let Some(dispatcher) = guard.as_ref() { + dispatcher.request_id_for_message(message, request_id) + } else { + None + } + } + pub(crate) async fn handle_message( &self, message: ClientMessage, @@ -262,19 +290,7 @@ impl ServerRuntime { None } ClientMessage::Error(jsonrpc_error) => { - self.handler - .handle_error(&jsonrpc_error.error, self) - .await?; - if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { - tx_response - .send(ClientMessage::Error(jsonrpc_error)) - .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; - } else { - tracing::warn!( - "Received an error response with no corresponding request {:?}", - &jsonrpc_error.id - ); - } + self.handler.handle_error(jsonrpc_error.error, self).await?; None } // The response is the result of a request, it is processed at the transport level. @@ -285,7 +301,7 @@ impl ServerRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received a response with no corresponding request: {:?}", + "Received response or error without a matching request: {:?}", &response.id ); } @@ -455,7 +471,6 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(HashMap::new()), client_details_tx, client_details_rx, - request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -482,7 +497,6 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(map), client_details_tx, client_details_rx, - request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index ea19e19..26f37e1 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -177,7 +177,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, - jsonrpc_error: &RpcError, + jsonrpc_error: RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index e0e7108..154b4bc 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -87,7 +87,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> } async fn handle_error( &self, - jsonrpc_error: &RpcError, + jsonrpc_error: RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 2b155fa..511731c 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -3,6 +3,3 @@ pub mod mcp_client; pub mod mcp_handler; #[cfg(feature = "server")] pub mod mcp_server; -mod request_id_gen; - -pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 1883581..8e72c26 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -10,7 +10,7 @@ use crate::schema::{ InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams, ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest, ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams, - LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId, + LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, @@ -35,6 +35,16 @@ pub trait McpClient: Sync + Send { fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; + #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] + fn get_client_info(&self) -> &InitializeRequestParams { + self.client_info() + } + + #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] + fn get_server_info(&self) -> Option { + self.server_info() + } + /// Checks whether the server has been initialized with client fn is_initialized(&self) -> bool { self.server_info().is_some() @@ -47,12 +57,23 @@ pub trait McpClient: Sync + Send { .map(|server_details| server_details.server_info) } + #[deprecated(since = "0.2.0", note = "Use `server_version()` instead.")] + fn get_server_version(&self) -> Option { + self.server_info() + .map(|server_details| server_details.server_info) + } + /// Returns the server's capabilities. /// After initialization has completed, this will be populated with the server's reported capabilities. fn server_capabilities(&self) -> Option { self.server_info().map(|item| item.capabilities) } + #[deprecated(since = "0.2.0", note = "Use `server_capabilities()` instead.")] + fn get_server_capabilities(&self) -> Option { + self.server_info().map(|item| item.capabilities) + } + /// Checks if the server has tools available. /// /// This function retrieves the server information and checks if the @@ -135,6 +156,10 @@ pub trait McpClient: Sync + Send { self.server_info() .map(|server_details| server_details.capabilities.logging.is_some()) } + #[deprecated(since = "0.2.0", note = "Use `instructions()` instead.")] + fn get_instructions(&self) -> Option { + self.server_info()?.instructions + } fn instructions(&self) -> Option { self.server_info()?.instructions @@ -150,15 +175,27 @@ pub trait McpClient: Sync + Send { request: RequestFromClient, timeout: Option, ) -> SdkResult { - let response = self - .send(MessageFromClient::RequestFromClient(request), None, timeout) + let sender = self.sender(); + let sender = sender.read().await; + let sender = sender + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + + let request_id = sender.next_request_id(); + + let mcp_message = + ClientMessage::from_message(MessageFromClient::from(request), Some(request_id))?; + let response = sender + .send_message(ClientMessages::Single(mcp_message), timeout) .await?; let server_message = response.ok_or_else(|| { RpcError::internal_error() - .with_message("An empty response was received from the client.".to_string()) + .with_message("An empty response was received from the server.".to_string()) })?; + let server_message = server_message.as_single()?; + if server_message.is_error() { return Err(server_message.as_error()?.error.into()); } @@ -168,10 +205,27 @@ pub trait McpClient: Sync + Send { async fn send( &self, - message: MessageFromClient, - request_id: Option, + message: ClientMessage, timeout: Option, - ) -> SdkResult>; + ) -> SdkResult> { + let sender = self.sender(); + let sender = sender.read().await; + let sender = sender + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + + let response = sender + .send_message(ClientMessages::Single(message), timeout) + .await?; + + match response { + Some(res) => { + let server_results = res.as_single()?; + Ok(Some(server_results)) + } + None => Ok(None), + } + } async fn send_batch( &self, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index 2974bfc..c86a623 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -24,11 +24,8 @@ pub trait McpServerHandler: Send + Sync { client_jsonrpc_request: RequestFromClient, runtime: &dyn McpServer, ) -> std::result::Result; - async fn handle_error( - &self, - jsonrpc_error: &RpcError, - runtime: &dyn McpServer, - ) -> SdkResult<()>; + async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpServer) + -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, @@ -44,11 +41,8 @@ pub trait McpClientHandler: Send + Sync { server_jsonrpc_request: RequestFromServer, runtime: &dyn McpClient, ) -> std::result::Result; - async fn handle_error( - &self, - jsonrpc_error: &RpcError, - runtime: &dyn McpClient, - ) -> SdkResult<()>; + async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpClient) + -> SdkResult<()>; async fn handle_notification( &self, server_jsonrpc_notification: NotificationFromServer, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 0130c33..a1d501d 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -2,8 +2,8 @@ use std::time::Duration; use crate::schema::{ schema_utils::{ - ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, - ResultFromClient, ServerMessage, + ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer, + RequestFromServer, ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult, @@ -29,12 +29,22 @@ pub trait McpServer: Sync + Send { async fn wait_for_initialization(&self); + #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] + fn get_client_info(&self) -> Option { + self.client_info() + } + + #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] + fn get_server_info(&self) -> &InitializeResult { + self.server_info() + } + async fn send( &self, message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult>; + ) -> SdkResult>; async fn send_batch( &self, @@ -74,11 +84,13 @@ pub trait McpServer: Sync + Send { .send(MessageFromServer::RequestFromServer(request), None, timeout) .await?; - let client_message = response.ok_or_else(|| { + let client_messages = response.ok_or_else(|| { RpcError::internal_error() .with_message("An empty response was received from the client.".to_string()) })?; + let client_message = client_messages.as_single()?; + if client_message.is_error() { return Err(client_message.as_error()?.error.into()); } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs deleted file mode 100644 index 2372ae9..0000000 --- a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::sync::atomic::AtomicI64; - -use crate::schema::{schema_utils::McpMessage, RequestId}; -use async_trait::async_trait; - -/// A trait for generating and managing request IDs in a thread-safe manner. -/// -/// Implementors provide functionality to generate unique request IDs, retrieve the last -/// generated ID, and reset the ID counter. -#[async_trait] -pub trait RequestIdGen: Send + Sync { - fn next_request_id(&self) -> RequestId; - #[allow(unused)] - fn last_request_id(&self) -> Option; - #[allow(unused)] - fn reset_to(&self, id: u64); - - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - fn request_id_for_message( - &self, - message: &dyn McpMessage, - request_id: Option, - ) -> Option { - // we need to produce next request_id for requests - if message.is_request() { - // request_id should be None for requests - assert!(request_id.is_none()); - Some(self.next_request_id()) - } else if !message.is_notification() { - // `request_id` must not be `None` for errors, notifications and responses - assert!(request_id.is_some()); - request_id - } else { - None - } - } -} - -pub struct RequestIdGenNumeric { - message_id_counter: AtomicI64, - last_message_id: AtomicI64, -} - -impl RequestIdGenNumeric { - pub fn new(initial_id: Option) -> Self { - Self { - message_id_counter: AtomicI64::new(initial_id.unwrap_or(0) as i64), - last_message_id: AtomicI64::new(-1), - } - } -} - -impl RequestIdGen for RequestIdGenNumeric { - /// Generates the next unique request ID as an integer. - /// - /// Increments the internal counter atomically and updates the last generated ID. - /// Uses `Relaxed` ordering for performance, as the counter only needs to ensure unique IDs. - fn next_request_id(&self) -> RequestId { - let id = self - .message_id_counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - // Store the new ID as the last generated ID - self.last_message_id - .store(id, std::sync::atomic::Ordering::Relaxed); - RequestId::Integer(id) - } - - /// Returns the last generated request ID, if any. - /// - /// Returns `None` if no ID has been generated (indicated by a sentinel value of -1). - /// Uses `Relaxed` ordering since the read operation doesn’t require synchronization - /// with other memory operations beyond atomicity. - fn last_request_id(&self) -> Option { - let last_id = self - .last_message_id - .load(std::sync::atomic::Ordering::Relaxed); - if last_id == -1 { - None - } else { - Some(RequestId::Integer(last_id)) - } - } - - /// Resets the internal counter to the specified ID. - /// - /// The provided `id` (u64) is converted to i64 and stored atomically. - fn reset_to(&self, id: u64) { - self.message_id_counter - .store(id as i64, std::sync::atomic::Ordering::Relaxed); - } -} diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 564db0d..57a3ea8 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -128,10 +128,8 @@ use futures::stream::Stream; // stream: &mut impl Stream>, pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), - event_count: usize, -) -> Option> { +) -> Option { let mut buffer = String::new(); - let mut events = vec![]; while let Some(item) = stream.next().await { match item { @@ -160,10 +158,7 @@ pub async fn read_sse_event_from_stream( // Return if data was found if let Some(data) = data { - events.push(data); - if events.len().eq(&event_count) { - return Some(events); - } + return Some(data); } } } @@ -176,9 +171,9 @@ pub async fn read_sse_event_from_stream( None } -pub async fn read_sse_event(response: Response, event_count: usize) -> Option> { +pub async fn read_sse_event(response: Response) -> Option { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream, event_count).await + read_sse_event_from_stream(&mut stream).await } pub fn test_client_info() -> InitializeRequestParams { diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http.rs index 23ca27f..5eb5e47 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http.rs @@ -169,8 +169,8 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); - let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let event = read_sse_event(response).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -220,8 +220,8 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); - let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let event = read_sse_event(response).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -345,8 +345,8 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .await .unwrap(); - let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&events[0]).unwrap(); + let event = read_sse_event(response).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -365,7 +365,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { server.hyper_runtime.await_server().await.unwrap() } -// should establish standalone SSE stream and receive server-initiated requests +// should establish standalone SSE stream and receive server-initiated messages #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { let (server, session_id) = initialize_server(None).await.unwrap(); @@ -394,59 +394,48 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { ); let hyper_server = Arc::new(server.hyper_runtime); + let hyper_server_clone = hyper_server.clone(); + let session_id_clone = session_id.to_string(); + + tokio::spawn(async move { + // Send a server-initiated notification that should appear on SSE stream with a valid request_id + hyper_server_clone + .list_roots(&session_id_clone, None) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(2250)).await; + + let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( + RequestId::Integer(0), + ListRootsResult { + meta: None, + roots: vec![], + } + .into(), + ); - // Send two server-initiated request that should appear on SSE stream with a valid request_id - for _ in 0..2 { - let hyper_server_clone = hyper_server.clone(); - let session_id_clone = session_id.to_string(); - tokio::spawn(async move { - hyper_server_clone - .list_roots(&session_id_clone, None) - .await - .unwrap(); - }); - } - - for i in 0..2 { - // send responses back to the server for two server initiated requests - let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( - RequestId::Integer(i), - ListRootsResult { - meta: None, - roots: vec![], - } - .into(), - ); - send_post_request( - &server.streamable_url, - &serde_json::to_string(&json_rpc_message).unwrap(), - Some(&session_id), - None, - ) - .await - .expect("Request failed"); - } - - // read two events from the sse stream - let events = read_sse_event(response, 2).await.unwrap(); + send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); - let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0]).unwrap(); + let event = read_sse_event(response).await.unwrap(); - let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request - else { - panic!("invalid message received!"); - }; + let message: ServerJsonrpcRequest = serde_json::from_str(&event).unwrap(); - let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1]).unwrap(); + println!(">>> message {:?} ", message); - let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message.request else { panic!("invalid message received!"); }; - // ensure request_ids are unique - assert!(message2.id != message1.id); - hyper_server.graceful_shutdown(ONE_MILLISECOND); } @@ -472,7 +461,7 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let mut stream = response.bytes_stream(); - let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); + let event = read_sse_event_from_stream(&mut stream).await.unwrap(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -501,7 +490,7 @@ async fn should_not_close_get_sse_stream() { .await .unwrap(); - let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); + let event = read_sse_event_from_stream(&mut stream).await.unwrap(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -713,8 +702,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_1.status(), StatusCode::OK); assert_eq!(response_2.status(), StatusCode::OK); - let events = read_sse_event(response_2, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let event = read_sse_event(response_2).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -729,8 +718,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() "Hello, Ali!" ); - let events = read_sse_event(response_1, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let event = read_sse_event(response_1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -1080,8 +1069,8 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { "text/event-stream" ); - let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&events[0]).unwrap(); + let event = read_sse_event(response).await.unwrap(); + let message: ServerMessages = serde_json::from_str(&event).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 08bdc21..2d2a377 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -5,7 +5,12 @@ use crate::{ utils::CancellationToken, IoStream, }; -use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; +use std::{ + collections::HashMap, + pin::Pin, + sync::{atomic::AtomicI64, Arc}, + time::Duration, +}; use tokio::task::JoinHandle; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -52,7 +57,12 @@ impl MCPStream { // rpc message stream that receives incoming messages - let sender = MessageDispatcher::new(pending_requests, writable, request_timeout); + let sender = MessageDispatcher::new( + pending_requests, + writable, + Arc::new(AtomicI64::new(0)), + request_timeout, + ); (stream, sender, error_io) } diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index ea1eb04..22d0b58 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -10,6 +10,7 @@ use futures::future::join_all; use std::collections::HashMap; use std::pin::Pin; +use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; @@ -30,6 +31,7 @@ use crate::McpDispatch; pub struct MessageDispatcher { pending_requests: Arc>>>, writable_std: Mutex>>, + message_id_counter: Arc, request_timeout: Duration, } @@ -47,15 +49,53 @@ impl MessageDispatcher { pub fn new( pending_requests: Arc>>>, writable_std: Mutex>>, + message_id_counter: Arc, request_timeout: Duration, ) -> Self { Self { pending_requests, writable_std, + message_id_counter, request_timeout, } } + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + pub fn request_id_for_message( + &self, + message: &impl McpMessage, + request_id: Option, + ) -> Option { + // we need to produce next request_id for requests + if message.is_request() { + // request_id should be None for requests + assert!(request_id.is_none()); + Some(self.next_request_id()) + } else if !message.is_notification() { + // `request_id` must not be `None` for errors, notifications and responses + assert!(request_id.is_some()); + request_id + } else { + None + } + } + pub fn next_request_id(&self) -> RequestId { + RequestId::Integer( + self.message_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + } + async fn store_pending_request( &self, request_id: RequestId, diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-core/src/handler.rs index f0bdefe..fcde15e 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-core/src/handler.rs @@ -98,7 +98,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: &RpcError, + error: RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-core-streamable-http/src/handler.rs index 1c69e8c..53f884c 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-core-streamable-http/src/handler.rs @@ -103,7 +103,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: &RpcError, + error: RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-core-sse/src/handler.rs index ab86e9e..a1a95e4 100644 --- a/examples/simple-mcp-client-core-sse/src/handler.rs +++ b/examples/simple-mcp-client-core-sse/src/handler.rs @@ -41,30 +41,16 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_notification( &self, - notification: NotificationFromServer, + _notification: NotificationFromServer, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { - if let NotificationFromServer::ServerNotification( - schema::ServerNotification::LoggingMessageNotification(logging_message_notification), - ) = notification - { - println!( - "Notification from server: {}", - logging_message_notification.params.data - ); - } else { - println!( - "A {} notification received from the server", - notification.method() - ); - }; - - Ok(()) + Err(RpcError::internal_error() + .with_message("handle_notification() Not implemented".to_string())) } async fn handle_error( &self, - _error: &RpcError, + _error: RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-core/src/handler.rs index bd5e4fe..a1a95e4 100644 --- a/examples/simple-mcp-client-core/src/handler.rs +++ b/examples/simple-mcp-client-core/src/handler.rs @@ -50,7 +50,7 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_error( &self, - _error: &RpcError, + _error: RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) From 95b91aad191e1b8777ca4a02612ab9183e0276d3 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Tue, 19 Aug 2025 20:57:18 -0300 Subject: [PATCH 17/33] feat!: improve request ID generation, remove deprecated methods and adding improvements --- Cargo.lock | 82 +++++++------- README.md | 1 + crates/rust-mcp-sdk/README.md | 1 + crates/rust-mcp-sdk/src/error.rs | 3 - .../src/mcp_handlers/mcp_client_handler.rs | 2 +- .../mcp_handlers/mcp_client_handler_core.rs | 2 +- .../src/mcp_handlers/mcp_server_handler.rs | 2 +- .../mcp_handlers/mcp_server_handler_core.rs | 2 +- .../rust-mcp-sdk/src/mcp_macros/tool_box.rs | 9 -- .../src/mcp_runtimes/client_runtime.rs | 65 +++++++++-- .../client_runtime/mcp_client_runtime.rs | 2 +- .../client_runtime/mcp_client_runtime_core.rs | 2 +- .../src/mcp_runtimes/server_runtime.rs | 70 +++++------- .../server_runtime/mcp_server_runtime.rs | 2 +- .../server_runtime/mcp_server_runtime_core.rs | 2 +- crates/rust-mcp-sdk/src/mcp_traits.rs | 3 + .../rust-mcp-sdk/src/mcp_traits/mcp_client.rs | 68 ++--------- .../src/mcp_traits/mcp_handler.rs | 14 ++- .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 20 +--- .../src/mcp_traits/request_id_gen.rs | 101 +++++++++++++++++ crates/rust-mcp-sdk/tests/common/common.rs | 13 ++- .../tests/test_streamable_http.rs | 107 ++++++++++-------- crates/rust-mcp-transport/src/mcp_stream.rs | 14 +-- .../src/message_dispatcher.rs | 40 ------- .../src/handler.rs | 2 +- .../src/handler.rs | 2 +- .../simple-mcp-client-core-sse/src/handler.rs | 22 +++- .../simple-mcp-client-core/src/handler.rs | 2 +- 28 files changed, 352 insertions(+), 303 deletions(-) create mode 100644 crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs diff --git a/Cargo.lock b/Cargo.lock index d51b2d6..9d4b91d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -118,7 +118,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "itoa", "matchit", @@ -170,7 +170,7 @@ dependencies = [ "fs-err", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "pin-project-lite", "rustls", @@ -239,9 +239,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +checksum = "6a65b545ab31d687cff52899d4890855fec459eb6afe0da6417b8a18da87aa29" [[package]] name = "bumpalo" @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.32" +version = "1.2.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e" +checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f" dependencies = [ "jobserver", "libc", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "cfg_aliases" @@ -870,13 +870,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", @@ -884,6 +885,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -896,7 +898,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ "http 1.3.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "rustls", "rustls-pki-types", @@ -919,7 +921,7 @@ dependencies = [ "futures-util", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "ipnet", "libc", "percent-encoding", @@ -1385,9 +1387,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.36" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", "syn", @@ -1395,9 +1397,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.97" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -1432,7 +1434,7 @@ dependencies = [ "rustc-hash 2.1.1", "rustls", "socket2 0.5.10", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "web-time", @@ -1453,7 +1455,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.14", + "thiserror 2.0.15", "tinyvec", "tracing", "web-time", @@ -1626,7 +1628,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-rustls", "hyper-util", "js-sys", @@ -1705,14 +1707,14 @@ dependencies = [ "axum", "axum-server", "futures", - "hyper 1.6.0", + "hyper 1.7.0", "reqwest", "rust-mcp-macros", "rust-mcp-schema", "rust-mcp-transport", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tokio-stream", "tracing", @@ -1731,7 +1733,7 @@ dependencies = [ "rust-mcp-schema", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tokio-stream", "tracing", @@ -1855,9 +1857,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.142" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ "itoa", "memchr", @@ -1932,7 +1934,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", ] @@ -1946,7 +1948,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", ] @@ -1960,7 +1962,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "tracing-subscriber", @@ -1976,7 +1978,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.15", "tokio", "tracing", "tracing-subscriber", @@ -2028,9 +2030,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.104" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -2068,11 +2070,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.14" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b0949c3a6c842cbde3f1686d6eea5a010516deb7085f79db747562d4102f41e" +checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850" dependencies = [ - "thiserror-impl 2.0.14", + "thiserror-impl 2.0.15", ] [[package]] @@ -2088,9 +2090,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.14" +version = "2.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc5b44b4ab9c2fdd0e0512e6bece8388e214c0749f5862b114cc5b7a25daf227" +checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" dependencies = [ "proc-macro2", "quote", @@ -2149,9 +2151,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ "tinyvec_macros", ] diff --git a/README.md b/README.md index ef5b4ed..1581d1d 100644 --- a/README.md +++ b/README.md @@ -526,6 +526,7 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index ef5b4ed..1581d1d 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -526,6 +526,7 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 2feab67..3de8d98 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -41,6 +41,3 @@ impl McpSdkError { None } } - -#[deprecated(since = "0.2.0", note = "Use `McpSdkError` instead.")] -pub type MCPSdkError = McpSdkError; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs index f8ee1a0..c6fb208 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs @@ -148,7 +148,7 @@ pub trait ClientHandler: Send + Sync + 'static { //********************// async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs index 3bbe5c9..a0afdf1 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs @@ -38,7 +38,7 @@ pub trait ClientHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP server. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError>; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index bf3fe17..89aebf5 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -319,7 +319,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index fffe2fc..e7b0e6d 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -45,7 +45,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP client. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpServer, ) -> std::result::Result<(), RpcError>; async fn on_server_started(&self, runtime: &dyn McpServer) { diff --git a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs index 3bd2735..a5b75d5 100644 --- a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs +++ b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs @@ -57,15 +57,6 @@ macro_rules! tool_box { )* ] } - - #[deprecated(since = "0.2.0", note = "Use `tools()` instead.")] - pub fn get_tools() -> Vec { - vec![ - $( - $tool::tool(), - )* - ] - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 8d113c3..7ee0815 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,20 +1,26 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, +use crate::{ + mcp_traits::{RequestIdGen, RequestIdGenNumeric}, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, + ServerMessages, + }, + InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, + RequestId, RpcError, ServerResult, }, - InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, - RpcError, ServerResult, }; use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::sync::{Arc, RwLock}; +use std::{ + sync::{Arc, RwLock}, + time::Duration, +}; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::sync::Mutex; @@ -41,6 +47,7 @@ pub struct ClientRuntime { // Details about the connected server server_details: Arc>>, handlers: Mutex>>>, + request_id_gen: Box, } impl ClientRuntime { @@ -61,6 +68,7 @@ impl ClientRuntime { client_details, server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -123,7 +131,19 @@ impl ClientRuntime { None } ServerMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ServerMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request: {:?}", + &jsonrpc_error.id + ); + } None } ServerMessage::Response(response) => { @@ -133,7 +153,7 @@ impl ClientRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } @@ -284,6 +304,33 @@ impl McpClient for ClientRuntime { } } + async fn send( + &self, + message: MessageFromClient, + request_id: Option, + timeout: Option, + ) -> SdkResult> { + let sender = self.sender(); + let sender = sender.read().await; + let sender = sender + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + + let response = sender + .send_message(ClientMessages::Single(mcp_message), timeout) + .await? + .map(|res| res.as_single()) + .transpose()?; + + Ok(response) + } + async fn is_shut_down(&self) -> bool { self.transport.is_shut_down().await } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 9ccd4d9..7925f07 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -113,7 +113,7 @@ impl McpClientHandler for ClientInternalHandler> { /// Handles errors received from the server by passing the request to self.handler async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 3bdc318..8cb8cff 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -83,7 +83,7 @@ impl McpClientHandler for ClientCoreInternalHandler> async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d787a10..49b5c3c 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -3,6 +3,7 @@ pub mod mcp_server_runtime_core; use crate::error::SdkResult; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; +use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric}; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, @@ -45,6 +46,7 @@ pub struct ServerRuntime { #[cfg(feature = "hyper-server")] session_id: Option, transport_map: tokio::sync::RwLock>, + request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, } @@ -79,7 +81,7 @@ impl McpServer for ServerRuntime { message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult> { + ) -> SdkResult> { let transport_map = self.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -87,14 +89,18 @@ impl McpServer for ServerRuntime { )?; let outgoing_request_id = self - .request_id_for_message(transport, &message, request_id) - .await; + .request_id_gen + .request_id_for_message(&message, request_id); let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?; - transport + + let response = transport .send_message(ServerMessages::Single(mcp_message), request_timeout) - .map_err(|err| err.into()) - .await + .await? + .map(|res| res.as_single()) + .transpose()?; + + Ok(response) } async fn send_batch( @@ -211,40 +217,6 @@ impl ServerRuntime { Ok(()) } - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - pub(crate) async fn request_id_for_message( - &self, - transport: &Arc< - dyn TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, - >, - >, - message: &MessageFromServer, - request_id: Option, - ) -> Option { - let message_sender = transport.message_sender(); - let guard = message_sender.read().await; - if let Some(dispatcher) = guard.as_ref() { - dispatcher.request_id_for_message(message, request_id) - } else { - None - } - } - pub(crate) async fn handle_message( &self, message: ClientMessage, @@ -290,7 +262,19 @@ impl ServerRuntime { None } ClientMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ClientMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request {:?}", + &jsonrpc_error.id + ); + } None } // The response is the result of a request, it is processed at the transport level. @@ -301,7 +285,7 @@ impl ServerRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } @@ -471,6 +455,7 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(HashMap::new()), client_details_tx, client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } @@ -497,6 +482,7 @@ impl ServerRuntime { transport_map: tokio::sync::RwLock::new(map), client_details_tx, client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), } } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index 26f37e1..ea19e19 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -177,7 +177,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index 154b4bc..e0e7108 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -87,7 +87,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> } async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpServer, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 511731c..2b155fa 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -3,3 +3,6 @@ pub mod mcp_client; pub mod mcp_handler; #[cfg(feature = "server")] pub mod mcp_server; +mod request_id_gen; + +pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 8e72c26..1883581 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -10,7 +10,7 @@ use crate::schema::{ InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams, ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest, ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams, - LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, + LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId, RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, @@ -35,16 +35,6 @@ pub trait McpClient: Sync + Send { fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> &InitializeRequestParams { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> Option { - self.server_info() - } - /// Checks whether the server has been initialized with client fn is_initialized(&self) -> bool { self.server_info().is_some() @@ -57,23 +47,12 @@ pub trait McpClient: Sync + Send { .map(|server_details| server_details.server_info) } - #[deprecated(since = "0.2.0", note = "Use `server_version()` instead.")] - fn get_server_version(&self) -> Option { - self.server_info() - .map(|server_details| server_details.server_info) - } - /// Returns the server's capabilities. /// After initialization has completed, this will be populated with the server's reported capabilities. fn server_capabilities(&self) -> Option { self.server_info().map(|item| item.capabilities) } - #[deprecated(since = "0.2.0", note = "Use `server_capabilities()` instead.")] - fn get_server_capabilities(&self) -> Option { - self.server_info().map(|item| item.capabilities) - } - /// Checks if the server has tools available. /// /// This function retrieves the server information and checks if the @@ -156,10 +135,6 @@ pub trait McpClient: Sync + Send { self.server_info() .map(|server_details| server_details.capabilities.logging.is_some()) } - #[deprecated(since = "0.2.0", note = "Use `instructions()` instead.")] - fn get_instructions(&self) -> Option { - self.server_info()?.instructions - } fn instructions(&self) -> Option { self.server_info()?.instructions @@ -175,27 +150,15 @@ pub trait McpClient: Sync + Send { request: RequestFromClient, timeout: Option, ) -> SdkResult { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let request_id = sender.next_request_id(); - - let mcp_message = - ClientMessage::from_message(MessageFromClient::from(request), Some(request_id))?; - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) + let response = self + .send(MessageFromClient::RequestFromClient(request), None, timeout) .await?; let server_message = response.ok_or_else(|| { RpcError::internal_error() - .with_message("An empty response was received from the server.".to_string()) + .with_message("An empty response was received from the client.".to_string()) })?; - let server_message = server_message.as_single()?; - if server_message.is_error() { return Err(server_message.as_error()?.error.into()); } @@ -205,27 +168,10 @@ pub trait McpClient: Sync + Send { async fn send( &self, - message: ClientMessage, + message: MessageFromClient, + request_id: Option, timeout: Option, - ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Single(message), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_single()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + ) -> SdkResult>; async fn send_batch( &self, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index c86a623..2974bfc 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -24,8 +24,11 @@ pub trait McpServerHandler: Send + Sync { client_jsonrpc_request: RequestFromClient, runtime: &dyn McpServer, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpServer) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: &dyn McpServer, + ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, @@ -41,8 +44,11 @@ pub trait McpClientHandler: Send + Sync { server_jsonrpc_request: RequestFromServer, runtime: &dyn McpClient, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpClient) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: &dyn McpClient, + ) -> SdkResult<()>; async fn handle_notification( &self, server_jsonrpc_notification: NotificationFromServer, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index a1d501d..0130c33 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -2,8 +2,8 @@ use std::time::Duration; use crate::schema::{ schema_utils::{ - ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer, - RequestFromServer, ResultFromClient, ServerMessage, + ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, + ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult, @@ -29,22 +29,12 @@ pub trait McpServer: Sync + Send { async fn wait_for_initialization(&self); - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> Option { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> &InitializeResult { - self.server_info() - } - async fn send( &self, message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult>; + ) -> SdkResult>; async fn send_batch( &self, @@ -84,13 +74,11 @@ pub trait McpServer: Sync + Send { .send(MessageFromServer::RequestFromServer(request), None, timeout) .await?; - let client_messages = response.ok_or_else(|| { + let client_message = response.ok_or_else(|| { RpcError::internal_error() .with_message("An empty response was received from the client.".to_string()) })?; - let client_message = client_messages.as_single()?; - if client_message.is_error() { return Err(client_message.as_error()?.error.into()); } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs new file mode 100644 index 0000000..2372ae9 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs @@ -0,0 +1,101 @@ +use std::sync::atomic::AtomicI64; + +use crate::schema::{schema_utils::McpMessage, RequestId}; +use async_trait::async_trait; + +/// A trait for generating and managing request IDs in a thread-safe manner. +/// +/// Implementors provide functionality to generate unique request IDs, retrieve the last +/// generated ID, and reset the ID counter. +#[async_trait] +pub trait RequestIdGen: Send + Sync { + fn next_request_id(&self) -> RequestId; + #[allow(unused)] + fn last_request_id(&self) -> Option; + #[allow(unused)] + fn reset_to(&self, id: u64); + + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + fn request_id_for_message( + &self, + message: &dyn McpMessage, + request_id: Option, + ) -> Option { + // we need to produce next request_id for requests + if message.is_request() { + // request_id should be None for requests + assert!(request_id.is_none()); + Some(self.next_request_id()) + } else if !message.is_notification() { + // `request_id` must not be `None` for errors, notifications and responses + assert!(request_id.is_some()); + request_id + } else { + None + } + } +} + +pub struct RequestIdGenNumeric { + message_id_counter: AtomicI64, + last_message_id: AtomicI64, +} + +impl RequestIdGenNumeric { + pub fn new(initial_id: Option) -> Self { + Self { + message_id_counter: AtomicI64::new(initial_id.unwrap_or(0) as i64), + last_message_id: AtomicI64::new(-1), + } + } +} + +impl RequestIdGen for RequestIdGenNumeric { + /// Generates the next unique request ID as an integer. + /// + /// Increments the internal counter atomically and updates the last generated ID. + /// Uses `Relaxed` ordering for performance, as the counter only needs to ensure unique IDs. + fn next_request_id(&self) -> RequestId { + let id = self + .message_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // Store the new ID as the last generated ID + self.last_message_id + .store(id, std::sync::atomic::Ordering::Relaxed); + RequestId::Integer(id) + } + + /// Returns the last generated request ID, if any. + /// + /// Returns `None` if no ID has been generated (indicated by a sentinel value of -1). + /// Uses `Relaxed` ordering since the read operation doesn’t require synchronization + /// with other memory operations beyond atomicity. + fn last_request_id(&self) -> Option { + let last_id = self + .last_message_id + .load(std::sync::atomic::Ordering::Relaxed); + if last_id == -1 { + None + } else { + Some(RequestId::Integer(last_id)) + } + } + + /// Resets the internal counter to the specified ID. + /// + /// The provided `id` (u64) is converted to i64 and stored atomically. + fn reset_to(&self, id: u64) { + self.message_id_counter + .store(id as i64, std::sync::atomic::Ordering::Relaxed); + } +} diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 57a3ea8..564db0d 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -128,8 +128,10 @@ use futures::stream::Stream; // stream: &mut impl Stream>, pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), -) -> Option { + event_count: usize, +) -> Option> { let mut buffer = String::new(); + let mut events = vec![]; while let Some(item) = stream.next().await { match item { @@ -158,7 +160,10 @@ pub async fn read_sse_event_from_stream( // Return if data was found if let Some(data) = data { - return Some(data); + events.push(data); + if events.len().eq(&event_count) { + return Some(events); + } } } } @@ -171,9 +176,9 @@ pub async fn read_sse_event_from_stream( None } -pub async fn read_sse_event(response: Response) -> Option { +pub async fn read_sse_event(response: Response, event_count: usize) -> Option> { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream).await + read_sse_event_from_stream(&mut stream, event_count).await } pub fn test_client_info() -> InitializeRequestParams { diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http.rs index 5eb5e47..23ca27f 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http.rs @@ -169,8 +169,8 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -220,8 +220,8 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -345,8 +345,8 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .await .unwrap(); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0]).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -365,7 +365,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { server.hyper_runtime.await_server().await.unwrap() } -// should establish standalone SSE stream and receive server-initiated messages +// should establish standalone SSE stream and receive server-initiated requests #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { let (server, session_id) = initialize_server(None).await.unwrap(); @@ -394,48 +394,59 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { ); let hyper_server = Arc::new(server.hyper_runtime); - let hyper_server_clone = hyper_server.clone(); - let session_id_clone = session_id.to_string(); - - tokio::spawn(async move { - // Send a server-initiated notification that should appear on SSE stream with a valid request_id - hyper_server_clone - .list_roots(&session_id_clone, None) - .await - .unwrap(); - }); - - tokio::time::sleep(Duration::from_millis(2250)).await; - - let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( - RequestId::Integer(0), - ListRootsResult { - meta: None, - roots: vec![], - } - .into(), - ); - send_post_request( - &server.streamable_url, - &serde_json::to_string(&json_rpc_message).unwrap(), - Some(&session_id), - None, - ) - .await - .expect("Request failed"); + // Send two server-initiated request that should appear on SSE stream with a valid request_id + for _ in 0..2 { + let hyper_server_clone = hyper_server.clone(); + let session_id_clone = session_id.to_string(); + tokio::spawn(async move { + hyper_server_clone + .list_roots(&session_id_clone, None) + .await + .unwrap(); + }); + } + + for i in 0..2 { + // send responses back to the server for two server initiated requests + let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( + RequestId::Integer(i), + ListRootsResult { + meta: None, + roots: vec![], + } + .into(), + ); + send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + } - let event = read_sse_event(response).await.unwrap(); + // read two events from the sse stream + let events = read_sse_event(response, 2).await.unwrap(); - let message: ServerJsonrpcRequest = serde_json::from_str(&event).unwrap(); + let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0]).unwrap(); - println!(">>> message {:?} ", message); + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request + else { + panic!("invalid message received!"); + }; - let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message.request + let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1]).unwrap(); + + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { panic!("invalid message received!"); }; + // ensure request_ids are unique + assert!(message2.id != message1.id); + hyper_server.graceful_shutdown(ONE_MILLISECOND); } @@ -461,7 +472,7 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let mut stream = response.bytes_stream(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -490,7 +501,7 @@ async fn should_not_close_get_sse_stream() { .await .unwrap(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( @@ -702,8 +713,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_1.status(), StatusCode::OK); assert_eq!(response_2.status(), StatusCode::OK); - let event = read_sse_event(response_2).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_2, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -718,8 +729,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() "Hello, Ali!" ); - let event = read_sse_event(response_1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_1, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -1069,8 +1080,8 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { "text/event-stream" ); - let event = read_sse_event(response).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerMessages = serde_json::from_str(&events[0]).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 2d2a377..08bdc21 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -5,12 +5,7 @@ use crate::{ utils::CancellationToken, IoStream, }; -use std::{ - collections::HashMap, - pin::Pin, - sync::{atomic::AtomicI64, Arc}, - time::Duration, -}; +use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; use tokio::task::JoinHandle; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -57,12 +52,7 @@ impl MCPStream { // rpc message stream that receives incoming messages - let sender = MessageDispatcher::new( - pending_requests, - writable, - Arc::new(AtomicI64::new(0)), - request_timeout, - ); + let sender = MessageDispatcher::new(pending_requests, writable, request_timeout); (stream, sender, error_io) } diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 22d0b58..ea1eb04 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -10,7 +10,6 @@ use futures::future::join_all; use std::collections::HashMap; use std::pin::Pin; -use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; @@ -31,7 +30,6 @@ use crate::McpDispatch; pub struct MessageDispatcher { pending_requests: Arc>>>, writable_std: Mutex>>, - message_id_counter: Arc, request_timeout: Duration, } @@ -49,53 +47,15 @@ impl MessageDispatcher { pub fn new( pending_requests: Arc>>>, writable_std: Mutex>>, - message_id_counter: Arc, request_timeout: Duration, ) -> Self { Self { pending_requests, writable_std, - message_id_counter, request_timeout, } } - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - pub fn request_id_for_message( - &self, - message: &impl McpMessage, - request_id: Option, - ) -> Option { - // we need to produce next request_id for requests - if message.is_request() { - // request_id should be None for requests - assert!(request_id.is_none()); - Some(self.next_request_id()) - } else if !message.is_notification() { - // `request_id` must not be `None` for errors, notifications and responses - assert!(request_id.is_some()); - request_id - } else { - None - } - } - pub fn next_request_id(&self) -> RequestId { - RequestId::Integer( - self.message_id_counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed), - ) - } - async fn store_pending_request( &self, request_id: RequestId, diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-core/src/handler.rs index fcde15e..f0bdefe 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-core/src/handler.rs @@ -98,7 +98,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, + error: &RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-core-streamable-http/src/handler.rs index 53f884c..1c69e8c 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-core-streamable-http/src/handler.rs @@ -103,7 +103,7 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, + error: &RpcError, _: &dyn McpServer, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-core-sse/src/handler.rs index a1a95e4..ab86e9e 100644 --- a/examples/simple-mcp-client-core-sse/src/handler.rs +++ b/examples/simple-mcp-client-core-sse/src/handler.rs @@ -41,16 +41,30 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_notification( &self, - _notification: NotificationFromServer, + notification: NotificationFromServer, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { - Err(RpcError::internal_error() - .with_message("handle_notification() Not implemented".to_string())) + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) } async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-core/src/handler.rs index a1a95e4..bd5e4fe 100644 --- a/examples/simple-mcp-client-core/src/handler.rs +++ b/examples/simple-mcp-client-core/src/handler.rs @@ -50,7 +50,7 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) From f05c2344c8d4f30bc94f156dc42c564886052f6d Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Tue, 19 Aug 2025 21:38:04 -0300 Subject: [PATCH 18/33] chore: release main (#88) * chore: release main * chore: update Cargo.toml for release --------- Co-authored-by: github-actions[bot] --- .release-manifest.json | 20 +++++++++---------- Cargo.lock | 20 +++++++++---------- Cargo.toml | 2 +- crates/rust-mcp-sdk/CHANGELOG.md | 11 ++++++++++ crates/rust-mcp-sdk/Cargo.toml | 2 +- crates/rust-mcp-transport/CHANGELOG.md | 11 ++++++++++ crates/rust-mcp-transport/Cargo.toml | 2 +- .../hello-world-mcp-server-core/Cargo.toml | 2 +- examples/hello-world-mcp-server/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../simple-mcp-client-core-sse/Cargo.toml | 2 +- examples/simple-mcp-client-core/Cargo.toml | 2 +- examples/simple-mcp-client-sse/Cargo.toml | 2 +- examples/simple-mcp-client/Cargo.toml | 2 +- 15 files changed, 53 insertions(+), 31 deletions(-) diff --git a/.release-manifest.json b/.release-manifest.json index a7e7c0e..67502c7 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,13 @@ { - "crates/rust-mcp-sdk": "0.5.3", + "crates/rust-mcp-sdk": "0.6.0", "crates/rust-mcp-macros": "0.5.1", - "crates/rust-mcp-transport": "0.4.1", - "examples/hello-world-mcp-server": "0.1.27", - "examples/hello-world-mcp-server-core": "0.1.18", - "examples/simple-mcp-client": "0.1.27", - "examples/simple-mcp-client-core": "0.1.27", - "examples/hello-world-server-core-streamable-http": "0.1.18", - "examples/hello-world-server-streamable-http": "0.1.27", - "examples/simple-mcp-client-core-sse": "0.1.18", - "examples/simple-mcp-client-sse": "0.1.18" + "crates/rust-mcp-transport": "0.5.0", + "examples/hello-world-mcp-server": "0.1.28", + "examples/hello-world-mcp-server-core": "0.1.19", + "examples/simple-mcp-client": "0.1.28", + "examples/simple-mcp-client-core": "0.1.28", + "examples/hello-world-server-core-streamable-http": "0.1.19", + "examples/hello-world-server-streamable-http": "0.1.28", + "examples/simple-mcp-client-core-sse": "0.1.19", + "examples/simple-mcp-client-sse": "0.1.19" } diff --git a/Cargo.lock b/Cargo.lock index 9d4b91d..7554175 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -688,7 +688,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "hello-world-mcp-server" -version = "0.1.27" +version = "0.1.28" dependencies = [ "async-trait", "futures", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "hello-world-mcp-server-core" -version = "0.1.18" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -714,7 +714,7 @@ dependencies = [ [[package]] name = "hello-world-server-core-streamable-http" -version = "0.1.18" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -728,7 +728,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http" -version = "0.1.27" +version = "0.1.28" dependencies = [ "async-trait", "futures", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.5.3" +version = "0.6.0" dependencies = [ "async-trait", "axum", @@ -1724,7 +1724,7 @@ dependencies = [ [[package]] name = "rust-mcp-transport" -version = "0.4.1" +version = "0.5.0" dependencies = [ "async-trait", "bytes", @@ -1926,7 +1926,7 @@ dependencies = [ [[package]] name = "simple-mcp-client" -version = "0.1.27" +version = "0.1.28" dependencies = [ "async-trait", "colored", @@ -1940,7 +1940,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core" -version = "0.1.27" +version = "0.1.28" dependencies = [ "async-trait", "colored", @@ -1954,7 +1954,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core-sse" -version = "0.1.18" +version = "0.1.19" dependencies = [ "async-trait", "colored", @@ -1970,7 +1970,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse" -version = "0.1.18" +version = "0.1.19" dependencies = [ "async-trait", "colored", diff --git a/Cargo.toml b/Cargo.toml index a85b5a7..13d723c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ members = [ [workspace.dependencies] # Workspace member crates -rust-mcp-transport = { version = "0.4.1", path = "crates/rust-mcp-transport", default-features = false } +rust-mcp-transport = { version = "0.5.0", path = "crates/rust-mcp-transport", default-features = false } rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false } rust-mcp-macros = { version = "0.5.1", path = "crates/rust-mcp-macros", default-features = false } diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index 720c438..bd1c8a8 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [0.6.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.3...rust-mcp-sdk-v0.6.0) (2025-08-19) + + +### ⚠ BREAKING CHANGES + +* improve request ID generation, remove deprecated methods and adding improvements + +### πŸš€ Features + +* Improve request ID generation, remove deprecated methods and adding improvements ([95b91aa](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/95b91aad191e1b8777ca4a02612ab9183e0276d3)) + ## [0.5.3](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.2...rust-mcp-sdk-v0.5.3) (2025-08-19) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 6e05365..d0553c8 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.5.3" +version = "0.6.0" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." diff --git a/crates/rust-mcp-transport/CHANGELOG.md b/crates/rust-mcp-transport/CHANGELOG.md index 1ffd363..bfce3b5 100644 --- a/crates/rust-mcp-transport/CHANGELOG.md +++ b/crates/rust-mcp-transport/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [0.5.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.4.1...rust-mcp-transport-v0.5.0) (2025-08-19) + + +### ⚠ BREAKING CHANGES + +* improve request ID generation, remove deprecated methods and adding improvements + +### πŸš€ Features + +* Improve request ID generation, remove deprecated methods and adding improvements ([95b91aa](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/95b91aad191e1b8777ca4a02612ab9183e0276d3)) + ## [0.4.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.4.0...rust-mcp-transport-v0.4.1) (2025-08-12) diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index 94fd5ba..78c812b 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-transport" -version = "0.4.1" +version = "0.5.0" authors = ["Ali Hashemi"] categories = ["data-structures"] description = "Transport implementations for the MCP (Model Context Protocol) within the rust-mcp-sdk ecosystem, enabling asynchronous data exchange and efficient message handling between MCP clients and servers." diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-core/Cargo.toml index b1256a5..a725e37 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-core" -version = "0.1.18" +version = "0.1.19" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server/Cargo.toml index 0f1b5d1..80faa71 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server" -version = "0.1.27" +version = "0.1.28" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-core-streamable-http/Cargo.toml index afc9c29..234e6fc 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-core-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-core-streamable-http" -version = "0.1.18" +version = "0.1.19" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index 3abc10d..3a5ffd3 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.27" +version = "0.1.28" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-core-sse/Cargo.toml index d66a7cd..52322a7 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-core-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core-sse" -version = "0.1.18" +version = "0.1.19" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-core/Cargo.toml index 9a9c439..f1b4709 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core" -version = "0.1.27" +version = "0.1.28" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 3b60bc9..1bde25c 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.18" +version = "0.1.19" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client/Cargo.toml index 39c2bc5..5b81f02 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client" -version = "0.1.27" +version = "0.1.28" edition = "2021" publish = false license = "MIT" From f2f0afb542f6ff036a28cf01e102b27ce940665b Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Thu, 28 Aug 2025 06:39:19 -0300 Subject: [PATCH 19/33] fix: session ID access in handlers and add helper for listing active (#90) sessions --- crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs | 6 ++++++ .../src/hyper_servers/routes/hyper_utils.rs | 5 +---- crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs | 10 +++++----- crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 7 +++++-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs index 109dde5..85cf791 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -70,6 +70,12 @@ impl HyperRuntime { result.map_err(|err| err.into()) } + /// Returns a list of active session IDs from the session store. + pub async fn sessions(&self) -> Vec { + self.state.session_store.keys().await + } + + /// Retrieves the runtime associated with the given session ID from the session store. pub async fn runtime_by_session( &self, session_id: &SessionId, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 79bf226..0a77913 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -172,10 +172,7 @@ pub async fn start_new_session( session_id.to_owned(), )); - tracing::info!( - "a new client joined : {}", - runtime.session_id().await.unwrap_or_default().to_owned() - ); + tracing::info!("a new client joined : {}", &session_id); let response = create_sse_stream( runtime.clone(), diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 49b5c3c..44f3e53 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -197,6 +197,11 @@ impl McpServer for ServerRuntime { } Ok(()) } + + #[cfg(feature = "hyper-server")] + fn session_id(&self) -> Option { + self.session_id.to_owned() + } } impl ServerRuntime { @@ -435,11 +440,6 @@ impl ServerRuntime { } } - #[cfg(feature = "hyper-server")] - pub(crate) async fn session_id(&self) -> Option { - self.session_id.to_owned() - } - #[cfg(feature = "hyper-server")] pub(crate) fn new_instance( server_details: Arc, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 0130c33..2eab9db 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use crate::schema::{ schema_utils::{ ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, @@ -16,6 +14,8 @@ use crate::schema::{ SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; use async_trait::async_trait; +use rust_mcp_transport::SessionId; +use std::time::Duration; use crate::{error::SdkResult, utils::format_assertion_message}; @@ -405,4 +405,7 @@ pub trait McpServer: Sync + Send { } Ok(()) } + + #[cfg(feature = "hyper-server")] + fn session_id(&self) -> Option; } From a219c260a30806befe8628e691b51adaed5b0693 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Thu, 28 Aug 2025 06:51:18 -0300 Subject: [PATCH 20/33] chore: release main (#91) --- .release-manifest.json | 18 +++++++++--------- Cargo.lock | 18 +++++++++--------- crates/rust-mcp-sdk/CHANGELOG.md | 7 +++++++ crates/rust-mcp-sdk/Cargo.toml | 2 +- .../hello-world-mcp-server-core/Cargo.toml | 2 +- examples/hello-world-mcp-server/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- examples/simple-mcp-client-core-sse/Cargo.toml | 2 +- examples/simple-mcp-client-core/Cargo.toml | 2 +- examples/simple-mcp-client-sse/Cargo.toml | 2 +- examples/simple-mcp-client/Cargo.toml | 2 +- 12 files changed, 34 insertions(+), 27 deletions(-) diff --git a/.release-manifest.json b/.release-manifest.json index 67502c7..6bac6a9 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,13 @@ { - "crates/rust-mcp-sdk": "0.6.0", + "crates/rust-mcp-sdk": "0.6.1", "crates/rust-mcp-macros": "0.5.1", "crates/rust-mcp-transport": "0.5.0", - "examples/hello-world-mcp-server": "0.1.28", - "examples/hello-world-mcp-server-core": "0.1.19", - "examples/simple-mcp-client": "0.1.28", - "examples/simple-mcp-client-core": "0.1.28", - "examples/hello-world-server-core-streamable-http": "0.1.19", - "examples/hello-world-server-streamable-http": "0.1.28", - "examples/simple-mcp-client-core-sse": "0.1.19", - "examples/simple-mcp-client-sse": "0.1.19" + "examples/hello-world-mcp-server": "0.1.29", + "examples/hello-world-mcp-server-core": "0.1.20", + "examples/simple-mcp-client": "0.1.29", + "examples/simple-mcp-client-core": "0.1.29", + "examples/hello-world-server-core-streamable-http": "0.1.20", + "examples/hello-world-server-streamable-http": "0.1.29", + "examples/simple-mcp-client-core-sse": "0.1.20", + "examples/simple-mcp-client-sse": "0.1.20" } diff --git a/Cargo.lock b/Cargo.lock index 7554175..061edf9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -688,7 +688,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "hello-world-mcp-server" -version = "0.1.28" +version = "0.1.29" dependencies = [ "async-trait", "futures", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "hello-world-mcp-server-core" -version = "0.1.19" +version = "0.1.20" dependencies = [ "async-trait", "futures", @@ -714,7 +714,7 @@ dependencies = [ [[package]] name = "hello-world-server-core-streamable-http" -version = "0.1.19" +version = "0.1.20" dependencies = [ "async-trait", "futures", @@ -728,7 +728,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http" -version = "0.1.28" +version = "0.1.29" dependencies = [ "async-trait", "futures", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.6.0" +version = "0.6.1" dependencies = [ "async-trait", "axum", @@ -1926,7 +1926,7 @@ dependencies = [ [[package]] name = "simple-mcp-client" -version = "0.1.28" +version = "0.1.29" dependencies = [ "async-trait", "colored", @@ -1940,7 +1940,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core" -version = "0.1.28" +version = "0.1.29" dependencies = [ "async-trait", "colored", @@ -1954,7 +1954,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core-sse" -version = "0.1.19" +version = "0.1.20" dependencies = [ "async-trait", "colored", @@ -1970,7 +1970,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse" -version = "0.1.19" +version = "0.1.20" dependencies = [ "async-trait", "colored", diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index bd1c8a8..057dffd 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.6.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.0...rust-mcp-sdk-v0.6.1) (2025-08-28) + + +### πŸ› Bug Fixes + +* Session ID access in handlers and add helper for listing active ([#90](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/90)) ([f2f0afb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/f2f0afb542f6ff036a28cf01e102b27ce940665b)) + ## [0.6.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.3...rust-mcp-sdk-v0.6.0) (2025-08-19) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index d0553c8..161d813 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.6.0" +version = "0.6.1" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-core/Cargo.toml index a725e37..c28b8c3 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-core" -version = "0.1.19" +version = "0.1.20" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server/Cargo.toml index 80faa71..cd8f63d 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server" -version = "0.1.28" +version = "0.1.29" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-core-streamable-http/Cargo.toml index 234e6fc..7ae24d4 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-core-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-core-streamable-http" -version = "0.1.19" +version = "0.1.20" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index 3a5ffd3..3e763c1 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.28" +version = "0.1.29" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-core-sse/Cargo.toml index 52322a7..704ae28 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-core-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core-sse" -version = "0.1.19" +version = "0.1.20" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-core/Cargo.toml index f1b4709..84552a1 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core" -version = "0.1.28" +version = "0.1.29" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 1bde25c..9782db9 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.19" +version = "0.1.20" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client/Cargo.toml index 5b81f02..bae4943 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client" -version = "0.1.28" +version = "0.1.29" edition = "2021" publish = false license = "MIT" From 54cc8edb55c41455dd9211f296560e7a792a7b9c Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Sat, 30 Aug 2025 11:52:08 -0300 Subject: [PATCH 21/33] fix: tool-box macro panic on invalid requests (#92) --- Cargo.lock | 185 +++++++----------- .../rust-mcp-sdk/src/mcp_macros/tool_box.rs | 16 +- 2 files changed, 83 insertions(+), 118 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 061edf9..371a94a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -239,9 +239,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.2" +version = "2.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a65b545ab31d687cff52899d4890855fec459eb6afe0da6417b8a18da87aa29" +checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d" [[package]] name = "bumpalo" @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.33" +version = "1.2.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f" +checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc" dependencies = [ "jobserver", "libc", @@ -459,9 +459,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] @@ -626,7 +626,7 @@ dependencies = [ "js-sys", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasi 0.14.3+wasi-0.2.4", "wasm-bindgen", ] @@ -1020,9 +1020,9 @@ dependencies = [ [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -1041,9 +1041,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" +checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9" dependencies = [ "equivalent", "hashbrown", @@ -1066,9 +1066,9 @@ dependencies = [ [[package]] name = "io-uring" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" dependencies = [ "bitflags", "cfg-if", @@ -1108,9 +1108,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ "getrandom 0.3.3", "libc", @@ -1196,11 +1196,11 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -1269,12 +1269,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -1308,12 +1307,6 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking" version = "2.2.1" @@ -1345,9 +1338,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pin-project-lite" @@ -1363,9 +1356,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "potential_utf" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" dependencies = [ "zerovec", ] @@ -1422,9 +1415,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ "bytes", "cfg_aliases", @@ -1433,8 +1426,8 @@ dependencies = [ "quinn-udp", "rustc-hash 2.1.1", "rustls", - "socket2 0.5.10", - "thiserror 2.0.15", + "socket2 0.6.0", + "thiserror 2.0.16", "tokio", "tracing", "web-time", @@ -1442,9 +1435,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.12" +version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes", "getrandom 0.3.3", @@ -1455,7 +1448,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.15", + "thiserror 2.0.16", "tinyvec", "tracing", "web-time", @@ -1463,16 +1456,16 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcebb1209ee276352ef14ff8732e24cc2b02bbac986cd74a4c81bcb2f9881970" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.0", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -1571,47 +1564,32 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] name = "regex-syntax" -version = "0.6.29" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - -[[package]] -name = "regex-syntax" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" [[package]] name = "reqwest" @@ -1691,9 +1669,9 @@ dependencies = [ [[package]] name = "rust-mcp-schema" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0e71aee61257cd3d4a78fdc10c92c29e7a55c4f767119ffdafd837bb5e5cb9a" +checksum = "098436b06bfa4b88b110d12a5567cf37fd454735ee67cab7eb48bdbea0dd0e57" dependencies = [ "serde", "serde_json", @@ -1714,7 +1692,7 @@ dependencies = [ "rust-mcp-transport", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.16", "tokio", "tokio-stream", "tracing", @@ -1733,7 +1711,7 @@ dependencies = [ "rust-mcp-schema", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.16", "tokio", "tokio-stream", "tracing", @@ -1934,7 +1912,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.16", "tokio", ] @@ -1948,7 +1926,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.16", "tokio", ] @@ -1962,7 +1940,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.16", "tokio", "tracing", "tracing-subscriber", @@ -1978,7 +1956,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.15", + "thiserror 2.0.16", "tokio", "tracing", "tracing-subscriber", @@ -2070,11 +2048,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.15" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850" +checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" dependencies = [ - "thiserror-impl 2.0.15", + "thiserror-impl 2.0.16", ] [[package]] @@ -2090,9 +2068,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.15" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" +checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" dependencies = [ "proc-macro2", "quote", @@ -2321,14 +2299,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", @@ -2363,9 +2341,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.4" +version = "2.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" dependencies = [ "form_urlencoded", "idna", @@ -2431,11 +2409,11 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.2+wasi-0.2.4" +version = "0.14.3+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "6a51ae83037bdd272a9e28ce236db8c07016dd0d50c27038b3f407533c030c95" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] @@ -2563,28 +2541,6 @@ dependencies = [ "rustix", ] -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-link" version = "0.1.3" @@ -2770,13 +2726,10 @@ dependencies = [ ] [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" -dependencies = [ - "bitflags", -] +checksum = "052283831dbae3d879dc7f51f3d92703a316ca49f91540417d38591826127814" [[package]] name = "writeable" diff --git a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs index a5b75d5..3edb344 100644 --- a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs +++ b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs @@ -67,8 +67,20 @@ macro_rules! tool_box { /// Attempts to convert a tool request into the appropriate tool variant fn try_from(value: rust_mcp_sdk::schema::CallToolRequestParams) -> Result { - let v = serde_json::to_value(value.arguments.unwrap()) - .map_err(rust_mcp_sdk::schema::schema_utils::CallToolError::new)?; + let arguments = value + .arguments + .ok_or(rust_mcp_sdk::schema::schema_utils::CallToolError::invalid_arguments( + &value.name, + Some("Missing 'arguments' field in the request".to_string()) + ))?; + + let v = serde_json::to_value(arguments).map_err(|err| { + rust_mcp_sdk::schema::schema_utils::CallToolError::invalid_arguments( + &value.name, + Some(format!("{err}")), + ) + })?; + match value.name { $( name if name == $tool::tool_name().as_str() => { From 9770abc8f3818ad4039d580a3c05194942f14e18 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Sat, 30 Aug 2025 12:01:12 -0300 Subject: [PATCH 22/33] chore: release main (#93) --- .release-manifest.json | 18 +++++++++--------- Cargo.lock | 18 +++++++++--------- crates/rust-mcp-sdk/CHANGELOG.md | 7 +++++++ crates/rust-mcp-sdk/Cargo.toml | 2 +- .../hello-world-mcp-server-core/Cargo.toml | 2 +- examples/hello-world-mcp-server/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- examples/simple-mcp-client-core-sse/Cargo.toml | 2 +- examples/simple-mcp-client-core/Cargo.toml | 2 +- examples/simple-mcp-client-sse/Cargo.toml | 2 +- examples/simple-mcp-client/Cargo.toml | 2 +- 12 files changed, 34 insertions(+), 27 deletions(-) diff --git a/.release-manifest.json b/.release-manifest.json index 6bac6a9..716d65a 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,13 @@ { - "crates/rust-mcp-sdk": "0.6.1", + "crates/rust-mcp-sdk": "0.6.2", "crates/rust-mcp-macros": "0.5.1", "crates/rust-mcp-transport": "0.5.0", - "examples/hello-world-mcp-server": "0.1.29", - "examples/hello-world-mcp-server-core": "0.1.20", - "examples/simple-mcp-client": "0.1.29", - "examples/simple-mcp-client-core": "0.1.29", - "examples/hello-world-server-core-streamable-http": "0.1.20", - "examples/hello-world-server-streamable-http": "0.1.29", - "examples/simple-mcp-client-core-sse": "0.1.20", - "examples/simple-mcp-client-sse": "0.1.20" + "examples/hello-world-mcp-server": "0.1.30", + "examples/hello-world-mcp-server-core": "0.1.21", + "examples/simple-mcp-client": "0.1.30", + "examples/simple-mcp-client-core": "0.1.30", + "examples/hello-world-server-core-streamable-http": "0.1.21", + "examples/hello-world-server-streamable-http": "0.1.30", + "examples/simple-mcp-client-core-sse": "0.1.21", + "examples/simple-mcp-client-sse": "0.1.21" } diff --git a/Cargo.lock b/Cargo.lock index 371a94a..6f86732 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -688,7 +688,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "hello-world-mcp-server" -version = "0.1.29" +version = "0.1.30" dependencies = [ "async-trait", "futures", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "hello-world-mcp-server-core" -version = "0.1.20" +version = "0.1.21" dependencies = [ "async-trait", "futures", @@ -714,7 +714,7 @@ dependencies = [ [[package]] name = "hello-world-server-core-streamable-http" -version = "0.1.20" +version = "0.1.21" dependencies = [ "async-trait", "futures", @@ -728,7 +728,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http" -version = "0.1.29" +version = "0.1.30" dependencies = [ "async-trait", "futures", @@ -1679,7 +1679,7 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.6.1" +version = "0.6.2" dependencies = [ "async-trait", "axum", @@ -1904,7 +1904,7 @@ dependencies = [ [[package]] name = "simple-mcp-client" -version = "0.1.29" +version = "0.1.30" dependencies = [ "async-trait", "colored", @@ -1918,7 +1918,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core" -version = "0.1.29" +version = "0.1.30" dependencies = [ "async-trait", "colored", @@ -1932,7 +1932,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core-sse" -version = "0.1.20" +version = "0.1.21" dependencies = [ "async-trait", "colored", @@ -1948,7 +1948,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse" -version = "0.1.20" +version = "0.1.21" dependencies = [ "async-trait", "colored", diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index 057dffd..f3ffaf4 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.6.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.1...rust-mcp-sdk-v0.6.2) (2025-08-30) + + +### πŸ› Bug Fixes + +* Tool-box macro panic on invalid requests ([#92](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/92)) ([54cc8ed](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/54cc8edb55c41455dd9211f296560e7a792a7b9c)) + ## [0.6.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.0...rust-mcp-sdk-v0.6.1) (2025-08-28) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 161d813..9261df7 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.6.1" +version = "0.6.2" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-core/Cargo.toml index c28b8c3..1a9a684 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-core" -version = "0.1.20" +version = "0.1.21" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server/Cargo.toml index cd8f63d..73a6585 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server" -version = "0.1.29" +version = "0.1.30" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-core-streamable-http/Cargo.toml index 7ae24d4..08bd089 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-core-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-core-streamable-http" -version = "0.1.20" +version = "0.1.21" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index 3e763c1..fd0cc60 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.29" +version = "0.1.30" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-core-sse/Cargo.toml index 704ae28..fdf119e 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-core-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core-sse" -version = "0.1.20" +version = "0.1.21" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-core/Cargo.toml index 84552a1..6fa16a2 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core" -version = "0.1.29" +version = "0.1.30" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 9782db9..e529bb2 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.20" +version = "0.1.21" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client/Cargo.toml index bae4943..d524259 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client" -version = "0.1.29" +version = "0.1.30" edition = "2021" publish = false license = "MIT" From 9d8c1fbdf3ddb7c67ce1fb7dcb8e50b8ba2e1202 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Sun, 31 Aug 2025 20:13:34 -0300 Subject: [PATCH 23/33] fix: correct pending_requests instance (#94) --- crates/rust-mcp-transport/src/stdio.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 06931d2..582af5d 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -237,13 +237,11 @@ where Ok(stream) } else { - let pending_requests: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); let (stream, sender, error_stream) = MCPStream::create( Box::pin(tokio::io::stdin()), Mutex::new(Box::pin(tokio::io::stdout())), IoStream::Writable(Box::pin(tokio::io::stderr())), - pending_requests, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); From 3508e1e619bfd448bd02b3f7266ffd7d17c61f4e Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Sun, 31 Aug 2025 20:23:28 -0300 Subject: [PATCH 24/33] chore: release main (#95) * chore: release main * chore: update Cargo.toml for release --------- Co-authored-by: github-actions[bot] --- .release-manifest.json | 20 +++++++++---------- Cargo.lock | 20 +++++++++---------- Cargo.toml | 2 +- crates/rust-mcp-sdk/CHANGELOG.md | 2 ++ crates/rust-mcp-sdk/Cargo.toml | 2 +- crates/rust-mcp-transport/CHANGELOG.md | 7 +++++++ crates/rust-mcp-transport/Cargo.toml | 2 +- .../hello-world-mcp-server-core/Cargo.toml | 2 +- examples/hello-world-mcp-server/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../simple-mcp-client-core-sse/Cargo.toml | 2 +- examples/simple-mcp-client-core/Cargo.toml | 2 +- examples/simple-mcp-client-sse/Cargo.toml | 2 +- examples/simple-mcp-client/Cargo.toml | 2 +- 15 files changed, 40 insertions(+), 31 deletions(-) diff --git a/.release-manifest.json b/.release-manifest.json index 716d65a..97a0f63 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,13 @@ { - "crates/rust-mcp-sdk": "0.6.2", + "crates/rust-mcp-sdk": "0.6.3", "crates/rust-mcp-macros": "0.5.1", - "crates/rust-mcp-transport": "0.5.0", - "examples/hello-world-mcp-server": "0.1.30", - "examples/hello-world-mcp-server-core": "0.1.21", - "examples/simple-mcp-client": "0.1.30", - "examples/simple-mcp-client-core": "0.1.30", - "examples/hello-world-server-core-streamable-http": "0.1.21", - "examples/hello-world-server-streamable-http": "0.1.30", - "examples/simple-mcp-client-core-sse": "0.1.21", - "examples/simple-mcp-client-sse": "0.1.21" + "crates/rust-mcp-transport": "0.5.1", + "examples/hello-world-mcp-server": "0.1.31", + "examples/hello-world-mcp-server-core": "0.1.22", + "examples/simple-mcp-client": "0.1.31", + "examples/simple-mcp-client-core": "0.1.31", + "examples/hello-world-server-core-streamable-http": "0.1.22", + "examples/hello-world-server-streamable-http": "0.1.31", + "examples/simple-mcp-client-core-sse": "0.1.22", + "examples/simple-mcp-client-sse": "0.1.22" } diff --git a/Cargo.lock b/Cargo.lock index 6f86732..c10e354 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -688,7 +688,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "hello-world-mcp-server" -version = "0.1.30" +version = "0.1.31" dependencies = [ "async-trait", "futures", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "hello-world-mcp-server-core" -version = "0.1.21" +version = "0.1.22" dependencies = [ "async-trait", "futures", @@ -714,7 +714,7 @@ dependencies = [ [[package]] name = "hello-world-server-core-streamable-http" -version = "0.1.21" +version = "0.1.22" dependencies = [ "async-trait", "futures", @@ -728,7 +728,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http" -version = "0.1.30" +version = "0.1.31" dependencies = [ "async-trait", "futures", @@ -1679,7 +1679,7 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.6.2" +version = "0.6.3" dependencies = [ "async-trait", "axum", @@ -1702,7 +1702,7 @@ dependencies = [ [[package]] name = "rust-mcp-transport" -version = "0.5.0" +version = "0.5.1" dependencies = [ "async-trait", "bytes", @@ -1904,7 +1904,7 @@ dependencies = [ [[package]] name = "simple-mcp-client" -version = "0.1.30" +version = "0.1.31" dependencies = [ "async-trait", "colored", @@ -1918,7 +1918,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core" -version = "0.1.30" +version = "0.1.31" dependencies = [ "async-trait", "colored", @@ -1932,7 +1932,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-core-sse" -version = "0.1.21" +version = "0.1.22" dependencies = [ "async-trait", "colored", @@ -1948,7 +1948,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse" -version = "0.1.21" +version = "0.1.22" dependencies = [ "async-trait", "colored", diff --git a/Cargo.toml b/Cargo.toml index 13d723c..b4f7cca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ members = [ [workspace.dependencies] # Workspace member crates -rust-mcp-transport = { version = "0.5.0", path = "crates/rust-mcp-transport", default-features = false } +rust-mcp-transport = { version = "0.5.1", path = "crates/rust-mcp-transport", default-features = false } rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false } rust-mcp-macros = { version = "0.5.1", path = "crates/rust-mcp-macros", default-features = false } diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index f3ffaf4..db5a72b 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,7 @@ # Changelog +## [0.6.3](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.2...rust-mcp-sdk-v0.6.3) (2025-08-31) + ## [0.6.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.1...rust-mcp-sdk-v0.6.2) (2025-08-30) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 9261df7..48ea665 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.6.2" +version = "0.6.3" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." diff --git a/crates/rust-mcp-transport/CHANGELOG.md b/crates/rust-mcp-transport/CHANGELOG.md index bfce3b5..9a0d2e1 100644 --- a/crates/rust-mcp-transport/CHANGELOG.md +++ b/crates/rust-mcp-transport/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.5.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.5.0...rust-mcp-transport-v0.5.1) (2025-08-31) + + +### πŸ› Bug Fixes + +* Correct pending_requests instance ([#94](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/94)) ([9d8c1fb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/9d8c1fbdf3ddb7c67ce1fb7dcb8e50b8ba2e1202)) + ## [0.5.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.4.1...rust-mcp-transport-v0.5.0) (2025-08-19) diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index 78c812b..ec061bb 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-transport" -version = "0.5.0" +version = "0.5.1" authors = ["Ali Hashemi"] categories = ["data-structures"] description = "Transport implementations for the MCP (Model Context Protocol) within the rust-mcp-sdk ecosystem, enabling asynchronous data exchange and efficient message handling between MCP clients and servers." diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-core/Cargo.toml index 1a9a684..bbab301 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-core" -version = "0.1.21" +version = "0.1.22" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server/Cargo.toml index 73a6585..63a54af 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server" -version = "0.1.30" +version = "0.1.31" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-core-streamable-http/Cargo.toml index 08bd089..99d1011 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-core-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-core-streamable-http" -version = "0.1.21" +version = "0.1.22" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index fd0cc60..df4296d 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.30" +version = "0.1.31" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-core-sse/Cargo.toml index fdf119e..0e32790 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-core-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core-sse" -version = "0.1.21" +version = "0.1.22" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-core/Cargo.toml index 6fa16a2..0dacc2d 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-core" -version = "0.1.30" +version = "0.1.31" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index e529bb2..14fd96b 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.21" +version = "0.1.22" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client/Cargo.toml index d524259..9599c46 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client" -version = "0.1.30" +version = "0.1.31" edition = "2021" publish = false license = "MIT" From 5dacceb0c2d18b8334744a13d438c6916bb7244c Mon Sep 17 00:00:00 2001 From: Mark A Date: Wed, 17 Sep 2025 23:59:37 +0100 Subject: [PATCH 25/33] feat: add tls-no-provider feature (#97) * Add tls-no-provider feature * Update README.md file. --- README.md | 1 + crates/rust-mcp-sdk/Cargo.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 1581d1d..507b55e 100644 --- a/README.md +++ b/README.md @@ -419,6 +419,7 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. - `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 48ea665..50d5e7f 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -67,6 +67,7 @@ hyper-server = [ "rust-mcp-transport/sse", ] ssl = ["axum-server/tls-rustls"] +tls-no-provider = ["axum-server/tls-rustls-no-provider"] macros = ["rust-mcp-macros/sdk"] # enables mcp protocol version 2025_06_18 From a2d6d23ab59fbc34d04526e2606f747f93a8468c Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Wed, 17 Sep 2025 20:22:40 -0300 Subject: [PATCH 26/33] feat!: update ServerHandler and ServerHandlerCore traits (#96) * feat: update server handler traits to accept Arc * feat: remove on_server_started in favor of on_initialized * Unblock Stream Loop in start Function with Task Spawning * chore: update start_stresam * chore: cleanup --- README.md | 4 +- crates/rust-mcp-sdk/README.md | 4 +- .../src/hyper_servers/routes/hyper_utils.rs | 4 +- .../src/hyper_servers/routes/sse_routes.rs | 4 +- .../src/mcp_handlers/mcp_server_handler.rs | 51 ++-- .../mcp_handlers/mcp_server_handler_core.rs | 17 +- .../src/mcp_runtimes/server_runtime.rs | 219 +++++++++++++----- .../server_runtime/mcp_server_runtime.rs | 19 +- .../server_runtime/mcp_server_runtime_core.rs | 13 +- .../src/mcp_traits/mcp_handler.rs | 11 +- .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 7 +- .../rust-mcp-sdk/tests/common/test_server.rs | 12 +- .../tests/test_protocol_compatibility.rs | 2 +- crates/rust-mcp-transport/src/stdio.rs | 7 +- doc/getting-started-mcp-server.md | 4 +- .../src/handler.rs | 8 +- .../hello-world-mcp-server/src/handler.rs | 5 +- examples/hello-world-mcp-server/src/main.rs | 5 +- .../src/handler.rs | 8 +- .../src/handler.rs | 8 +- 20 files changed, 247 insertions(+), 165 deletions(-) diff --git a/README.md b/README.md index 507b55e..1d334d6 100644 --- a/README.md +++ b/README.md @@ -180,7 +180,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +191,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc, ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 1581d1d..cbe7318 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -180,7 +180,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +191,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 0a77913..daf5d94 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -166,11 +166,11 @@ pub async fn start_new_session( let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let runtime: Arc = Arc::new(server_runtime::create_server_instance( + let runtime: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); tracing::info!("a new client joined : {}", &session_id); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index e1c00f8..a014e94 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -99,11 +99,11 @@ pub async fn handle_sse( .unwrap(); let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let server: Arc = Arc::new(server_runtime::create_server_instance( + let server: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); state .session_store diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 89aebf5..9b9577e 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -1,6 +1,7 @@ use crate::schema::{schema_utils::CallToolError, *}; use async_trait::async_trait; use serde_json::Value; +use std::sync::Arc; use crate::{mcp_traits::mcp_server::McpServer, utils::enforce_compatible_protocol_version}; @@ -15,7 +16,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, runtime: &dyn McpServer) {} + async fn on_initialized(&self, runtime: Arc) {} /// Handles the InitializeRequest from a client. /// @@ -29,7 +30,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialize_request( &self, initialize_request: InitializeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let mut server_info = runtime.server_info().to_owned(); // Provide compatibility for clients using older MCP protocol versions. @@ -65,7 +66,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_ping_request( &self, _: PingRequest, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result { Ok(Result::default()) } @@ -77,7 +78,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resources_request( &self, request: ListResourcesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -93,7 +94,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resource_templates_request( &self, request: ListResourceTemplatesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -109,7 +110,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_read_resource_request( &self, request: ReadResourceRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -125,7 +126,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_subscribe_request( &self, request: SubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -141,7 +142,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_unsubscribe_request( &self, request: UnsubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -157,7 +158,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_prompts_request( &self, request: ListPromptsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -173,7 +174,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_get_prompt_request( &self, request: GetPromptRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -189,7 +190,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -205,7 +206,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -220,7 +221,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_set_level_request( &self, request: SetLevelRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -236,7 +237,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_complete_request( &self, request: CompleteRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -252,7 +253,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_custom_request( &self, request: Value, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())) @@ -265,7 +266,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialized_notification( &self, notification: InitializedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -275,7 +276,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_cancelled_notification( &self, notification: CancelledNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -285,7 +286,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_progress_notification( &self, notification: ProgressNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -295,7 +296,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_roots_list_changed_notification( &self, notification: RootsListChangedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -320,18 +321,8 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } - - /// Called when the server has successfully started. - /// - /// Sends a "Server started successfully" message to stderr. - /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index e7b0e6d..9275da7 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -1,8 +1,8 @@ +use crate::mcp_traits::mcp_server::McpServer; use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; - -use crate::mcp_traits::mcp_server::McpServer; +use std::sync::Arc; /// Defines the `ServerHandlerCore` trait for handling Model Context Protocol (MCP) server operations. /// Unlike `ServerHandler`, this trait offers no default implementations, providing full control over MCP message handling @@ -14,7 +14,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, _runtime: &dyn McpServer) {} + async fn on_initialized(&self, _runtime: Arc) {} /// Asynchronously handles an incoming request from the client. /// @@ -26,7 +26,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; /// Asynchronously handles an incoming notification from the client. @@ -36,7 +36,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_notification( &self, notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; /// Asynchronously handles an error received from the client. @@ -46,11 +46,6 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 44f3e53..57ba260 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -22,9 +22,10 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; -use tokio::sync::{oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch}; pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; +const TASK_CHANNEL_CAPACITY: usize = 500; // Define a type alias for the TransportDispatcher trait object type TransportType = Arc< @@ -55,8 +56,6 @@ pub struct ServerRuntime { impl McpServer for ServerRuntime { /// Set the client details, storing them in client_details async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { - self.handler.on_server_started(self).await; - self.client_details_tx .send(Some(client_details)) .map_err(|_| { @@ -132,8 +131,9 @@ impl McpServer for ServerRuntime { } /// Main runtime loop, processes incoming messages and handles requests - async fn start(&self) -> SdkResult<()> { - let transport_map = self.transport_map.read().await; + async fn start(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -142,43 +142,88 @@ impl McpServer for ServerRuntime { let mut stream = transport.start().await?; + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + // Process incoming messages from the client while let Some(mcp_messages) = stream.next().await { match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, transport).await; - - match result { - Ok(result) => { - if let Some(result) = result { - transport - .send_message(ServerMessages::Single(result), None) - .await?; + let transport = transport.clone(); + let self = self.clone(); + let tx = tx.clone(); + + // Handle incoming messages in a separate task to avoid blocking the stream. + tokio::spawn(async move { + let result = self.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + // Send result to the main loop + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send result to channel: {}", error); } - Err(error) => { - tracing::error!("Error handling message : {}", error) - } - } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); + let transport = transport.clone(); + let self = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport + .send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => Err(error), + }; - if !results.is_empty() { - transport - .send_message(ServerMessages::Batch(results), None) - .await?; - } + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } } + + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } + return Ok(()); } @@ -223,7 +268,7 @@ impl ServerRuntime { } pub(crate) async fn handle_message( - &self, + self: &Arc, message: ClientMessage, transport: &Arc< dyn TransportDispatcher< @@ -240,7 +285,7 @@ impl ServerRuntime { ClientMessage::Request(client_jsonrpc_request) => { let result = self .handler - .handle_request(client_jsonrpc_request.request, self) + .handle_request(client_jsonrpc_request.request, self.clone()) .await; // create a response to send back to the client let response: MessageFromServer = match result { @@ -262,13 +307,13 @@ impl ServerRuntime { } ClientMessage::Notification(client_jsonrpc_notification) => { self.handler - .handle_notification(client_jsonrpc_notification.notification, self) + .handle_notification(client_jsonrpc_notification.notification, self.clone()) .await?; None } ClientMessage::Error(jsonrpc_error) => { self.handler - .handle_error(&jsonrpc_error.error, self) + .handle_error(&jsonrpc_error.error, self.clone()) .await?; if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { tx_response @@ -282,7 +327,6 @@ impl ServerRuntime { } None } - // The response is the result of a request, it is processed at the transport level. ClientMessage::Response(response) => { if let Some(tx_response) = transport.pending_request_tx(&response.id).await { tx_response @@ -379,7 +423,8 @@ impl ServerRuntime { self.store_transport(stream_id, Arc::new(transport)).await?; - let transport = self.transport_by_stream(stream_id).await?; + let self_clone = self.clone(); + let transport = self_clone.transport_by_stream(stream_id).await?; let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); let abort_alive_task = transport @@ -397,40 +442,96 @@ impl ServerRuntime { transport.consume_string_payload(&payload).await?; } + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + loop { tokio::select! { Some(mcp_messages) = stream.next() =>{ match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, &transport).await?; - if let Some(result) = result { - transport.send_message(ServerMessages::Single(result), None).await?; - } + let transport = transport.clone(); + let self_clone = self.clone(); + let tx = tx.clone(); + tokio::spawn(async move { + + let result = self_clone.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, &transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); - - - if !results.is_empty() { - transport.send_message(ServerMessages::Batch(results), None).await?; - } + let transport = transport.clone(); + let self_clone = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self_clone.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport.send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + }else { + Ok(None) + } + }, + Err(error) => Err(error), + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } + // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } return Ok(()); } } _ = &mut disconnect_rx => { + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } self.remove_transport(stream_id).await?; // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); @@ -445,10 +546,10 @@ impl ServerRuntime { server_details: Arc, handler: Arc, session_id: SessionId, - ) -> Self { + ) -> Arc { let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details, handler, session_id: Some(session_id), @@ -456,7 +557,7 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } pub(crate) fn new( @@ -469,12 +570,12 @@ impl ServerRuntime { ServerMessage, >, handler: Arc, - ) -> Self { + ) -> Arc { let mut map: HashMap = HashMap::new(); map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details: Arc::new(server_details), handler, #[cfg(feature = "hyper-server")] @@ -483,6 +584,6 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index ea19e19..5fbc43c 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -49,7 +49,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandler, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -62,7 +62,7 @@ pub(crate) fn create_server_instance( server_details: Arc, handler: Arc, session_id: SessionId, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new_instance(server_details, handler, session_id) } @@ -80,7 +80,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { match client_jsonrpc_request { schema_utils::RequestFromClient::ClientRequest(client_request) => { @@ -178,7 +178,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -187,7 +187,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { match client_jsonrpc_notification { schema_utils::NotificationFromClient::ClientNotification(client_notification) => { @@ -199,7 +199,10 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } ClientNotification::InitializedNotification(initialized_notification) => { self.handler - .handle_initialized_notification(initialized_notification, runtime) + .handle_initialized_notification( + initialized_notification, + runtime.clone(), + ) .await?; self.handler.on_initialized(runtime).await; } @@ -226,8 +229,4 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } Ok(()) } - - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index e0e7108..5ed2239 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -43,7 +43,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandlerCore, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -66,7 +66,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // store the client details if the request is a client initialization request if let schema_utils::RequestFromClient::ClientRequest(ClientRequest::InitializeRequest( @@ -88,7 +88,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -96,11 +96,11 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { // Trigger the `on_initialized()` callback if an `initialized_notification` is received from the client. if client_jsonrpc_notification.is_initialized_notification() { - self.handler.on_initialized(runtime).await; + self.handler.on_initialized(runtime.clone()).await; } // handle notification @@ -109,7 +109,4 @@ impl McpServerHandler for RuntimeCoreInternalHandler> .await?; Ok(()) } - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index 2974bfc..cb37f2a 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -6,9 +6,9 @@ use crate::schema::schema_utils::{NotificationFromClient, RequestFromClient, Res #[cfg(feature = "client")] use crate::schema::schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}; -use crate::schema::RpcError; - use crate::error::SdkResult; +use crate::schema::RpcError; +use std::sync::Arc; #[cfg(feature = "client")] use super::mcp_client::McpClient; @@ -18,21 +18,20 @@ use super::mcp_server::McpServer; #[cfg(feature = "server")] #[async_trait] pub trait McpServerHandler: Send + Sync { - async fn on_server_started(&self, runtime: &dyn McpServer); async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 2eab9db..dc860b6 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -13,16 +13,15 @@ use crate::schema::{ ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; +use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; use rust_mcp_transport::SessionId; -use std::time::Duration; - -use crate::{error::SdkResult, utils::format_assertion_message}; +use std::{sync::Arc, time::Duration}; //TODO: support options , such as enforceStrictCapabilities #[async_trait] pub trait McpServer: Sync + Send { - async fn start(&self) -> SdkResult<()>; + async fn start(self: Arc) -> SdkResult<()>; async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; fn server_info(&self) -> &InitializeResult; fn client_info(&self) -> Option; diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index aa8e2fb..176e0d2 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -17,7 +17,7 @@ pub mod test_server_common { mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, McpServer, SessionId, }; - use std::sync::RwLock; + use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::time::timeout; @@ -71,16 +71,10 @@ pub mod test_server_common { #[async_trait] impl ServerHandler for TestServerHandler { - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } - async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; @@ -94,7 +88,7 @@ pub mod test_server_common { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) diff --git a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs index 5c184cf..9f2fd95 100644 --- a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs +++ b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs @@ -30,7 +30,7 @@ mod protocol_compatibility_on_server { ); handler - .handle_initialize_request(InitializeRequest::new(initialize_request), &runtime) + .handle_initialize_request(InitializeRequest::new(initialize_request), runtime) .await } diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 582af5d..0b67d64 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -210,13 +210,12 @@ where .take() .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?; - let pending_requests_clone1 = self.pending_requests.clone(); - let pending_requests_clone2 = self.pending_requests.clone(); + let pending_requests_clone = self.pending_requests.clone(); tokio::spawn(async move { let _ = process.wait().await; // clean up pending requests to cancel waiting tasks - let mut pending_requests = pending_requests_clone1.lock().await; + let mut pending_requests = pending_requests_clone.lock().await; pending_requests.clear(); }); @@ -224,7 +223,7 @@ where Box::pin(stdout), Mutex::new(Box::pin(stdin)), IoStream::Readable(Box::pin(stderr)), - pending_requests_clone2, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 358b1b4..6fac258 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -160,7 +160,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, _request: ListToolsRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -173,7 +173,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-core/src/handler.rs index f0bdefe..acf55ea 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -90,7 +92,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -99,7 +101,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-mcp-server/src/handler.rs b/examples/hello-world-mcp-server/src/handler.rs index d9741a0..47925a0 100644 --- a/examples/hello-world-mcp-server/src/handler.rs +++ b/examples/hello-world-mcp-server/src/handler.rs @@ -4,6 +4,7 @@ use rust_mcp_sdk::schema::{ ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use std::sync::Arc; use crate::tools::GreetingTools; @@ -20,7 +21,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +34,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server/src/main.rs b/examples/hello-world-mcp-server/src/main.rs index 00ca6a7..98ff6f0 100644 --- a/examples/hello-world-mcp-server/src/main.rs +++ b/examples/hello-world-mcp-server/src/main.rs @@ -1,6 +1,8 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, @@ -40,7 +42,8 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; // STEP 4: create a MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); + let server: Arc = + server_runtime::create_server(server_details, transport, handler); // STEP 5: Start the server if let Err(start_error) = server.start().await { diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-core-streamable-http/src/handler.rs index 1c69e8c..7941075 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-core-streamable-http/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -95,7 +97,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -104,7 +106,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-server-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs index b8ce355..c4732d2 100644 --- a/examples/hello-world-server-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, @@ -20,7 +22,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +35,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = @@ -45,6 +47,4 @@ impl ServerHandler for MyServerHandler { GreetingTools::SayGoodbyeTool(say_goodbye_tool) => say_goodbye_tool.call_tool(), } } - - async fn on_server_started(&self, runtime: &dyn McpServer) {} } From abb0c36126b0a397bc20a1de36c5a5a80924a01e Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Wed, 17 Sep 2025 20:43:06 -0300 Subject: [PATCH 27/33] feat!: add Streamable HTTP Client , multiple refactoring and improvements (#98) * Add Streamable HTTP Client and multiple refactoring and improvements * chore: typos * chore: update readme * merge main --- .release-manifest.json | 20 +- Cargo.lock | 94 +- Cargo.toml | 17 +- README.md | 147 +++- crates/rust-mcp-sdk/Cargo.toml | 41 +- crates/rust-mcp-sdk/README.md | 145 ++- crates/rust-mcp-sdk/src/error.rs | 37 +- .../src/hyper_servers/app_state.rs | 11 +- .../src/hyper_servers/routes/hyper_utils.rs | 96 +- .../src/hyper_servers/routes/sse_routes.rs | 18 +- .../routes/streamable_http_routes.rs | 7 +- .../rust-mcp-sdk/src/hyper_servers/server.rs | 13 +- .../src/hyper_servers/session_store.rs | 24 - crates/rust-mcp-sdk/src/id_generator.rs | 5 + .../src/id_generator/fast_id_generator.rs | 53 ++ .../src/id_generator/uuid_generator.rs | 18 + crates/rust-mcp-sdk/src/lib.rs | 9 +- .../src/mcp_runtimes/client_runtime.rs | 531 ++++++++--- .../client_runtime/mcp_client_runtime.rs | 24 +- .../client_runtime/mcp_client_runtime_core.rs | 33 +- .../src/mcp_runtimes/server_runtime.rs | 66 +- .../server_runtime/mcp_server_runtime.rs | 2 +- .../server_runtime/mcp_server_runtime_core.rs | 2 +- crates/rust-mcp-sdk/src/mcp_traits.rs | 2 + .../src/mcp_traits/id_generator.rs | 12 + .../rust-mcp-sdk/src/mcp_traits/mcp_client.rs | 49 +- crates/rust-mcp-sdk/src/utils.rs | 43 +- crates/rust-mcp-sdk/tests/check_imports.rs | 5 +- crates/rust-mcp-sdk/tests/common/common.rs | 57 +- .../rust-mcp-sdk/tests/common/mock_server.rs | 528 +++++++++++ .../rust-mcp-sdk/tests/common/test_client.rs | 163 ++++ .../rust-mcp-sdk/tests/common/test_server.rs | 19 +- .../tests/test_streamable_http_client.rs | 823 ++++++++++++++++++ ...http.rs => test_streamable_http_server.rs} | 7 +- crates/rust-mcp-transport/Cargo.toml | 4 +- crates/rust-mcp-transport/README.md | 4 +- crates/rust-mcp-transport/src/client_sse.rs | 101 ++- .../src/client_streamable_http.rs | 515 +++++++++++ crates/rust-mcp-transport/src/constants.rs | 3 + crates/rust-mcp-transport/src/error.rs | 71 +- crates/rust-mcp-transport/src/lib.rs | 17 +- crates/rust-mcp-transport/src/mcp_stream.rs | 37 + .../src/message_dispatcher.rs | 82 +- crates/rust-mcp-transport/src/sse.rs | 4 +- crates/rust-mcp-transport/src/stdio.rs | 67 +- crates/rust-mcp-transport/src/transport.rs | 35 +- crates/rust-mcp-transport/src/utils.rs | 28 +- .../src/utils/http_utils.rs | 125 ++- .../src/utils/sse_parser.rs | 320 +++++++ .../src/utils/streamable_http_stream.rs | 374 ++++++++ .../rust-mcp-transport/tests/check_imports.rs | 5 +- development.md | 6 +- .../.gitignore | 0 .../Cargo.toml | 5 +- .../README.md | 8 +- .../src/handler.rs | 0 .../src/main.rs | 0 .../src/tools.rs | 0 .../Cargo.toml | 7 +- .../README.md | 8 +- .../src/handler.rs | 0 .../src/main.rs | 0 .../src/tools.rs | 0 .../.gitignore | 0 .../Cargo.toml | 5 +- .../README.md | 4 +- .../src/handler.rs | 0 .../src/main.rs | 0 .../src/tools.rs | 0 .../Cargo.toml | 1 + .../README.md | 2 +- .../src/handler.rs | 7 +- .../Cargo.toml | 5 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 1 + examples/simple-mcp-client-sse/Cargo.toml | 2 + examples/simple-mcp-client-sse/src/main.rs | 13 +- .../Cargo.toml | 6 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 0 .../Cargo.toml | 6 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 0 .../Cargo.toml | 29 + .../README.md | 40 + .../src/handler.rs | 72 ++ .../src/inquiry_utils.rs | 222 +++++ .../src/main.rs | 95 ++ .../Cargo.toml | 29 + .../README.md | 40 + .../src/handler.rs | 10 + .../src/inquiry_utils.rs | 222 +++++ .../src/main.rs | 99 +++ 99 files changed, 5282 insertions(+), 581 deletions(-) create mode 100644 crates/rust-mcp-sdk/src/id_generator.rs create mode 100644 crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs create mode 100644 crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs create mode 100644 crates/rust-mcp-sdk/tests/common/mock_server.rs create mode 100644 crates/rust-mcp-sdk/tests/common/test_client.rs create mode 100644 crates/rust-mcp-sdk/tests/test_streamable_http_client.rs rename crates/rust-mcp-sdk/tests/{test_streamable_http.rs => test_streamable_http_server.rs} (99%) create mode 100644 crates/rust-mcp-transport/src/client_streamable_http.rs create mode 100644 crates/rust-mcp-transport/src/constants.rs create mode 100644 crates/rust-mcp-transport/src/utils/sse_parser.rs create mode 100644 crates/rust-mcp-transport/src/utils/streamable_http_stream.rs rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/.gitignore (100%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/Cargo.toml (83%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/README.md (81%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/handler.rs (100%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/main.rs (100%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/tools.rs (100%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/Cargo.toml (85%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/README.md (84%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/handler.rs (100%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/main.rs (100%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/tools.rs (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/.gitignore (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/Cargo.toml (84%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/README.md (95%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/handler.rs (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/main.rs (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/tools.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/Cargo.toml (88%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/README.md (97%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/handler.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/main.rs (99%) rename examples/{simple-mcp-client => simple-mcp-client-stdio-core}/Cargo.toml (85%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/README.md (97%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/handler.rs (100%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/main.rs (100%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio}/Cargo.toml (87%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/README.md (97%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/handler.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/main.rs (100%) create mode 100644 examples/simple-mcp-client-streamable-http-core/Cargo.toml create mode 100644 examples/simple-mcp-client-streamable-http-core/README.md create mode 100644 examples/simple-mcp-client-streamable-http-core/src/handler.rs create mode 100644 examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs create mode 100644 examples/simple-mcp-client-streamable-http-core/src/main.rs create mode 100644 examples/simple-mcp-client-streamable-http/Cargo.toml create mode 100644 examples/simple-mcp-client-streamable-http/README.md create mode 100644 examples/simple-mcp-client-streamable-http/src/handler.rs create mode 100644 examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs create mode 100644 examples/simple-mcp-client-streamable-http/src/main.rs diff --git a/.release-manifest.json b/.release-manifest.json index 97a0f63..a645da6 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,15 @@ { "crates/rust-mcp-sdk": "0.6.3", "crates/rust-mcp-macros": "0.5.1", - "crates/rust-mcp-transport": "0.5.1", - "examples/hello-world-mcp-server": "0.1.31", - "examples/hello-world-mcp-server-core": "0.1.22", - "examples/simple-mcp-client": "0.1.31", - "examples/simple-mcp-client-core": "0.1.31", - "examples/hello-world-server-core-streamable-http": "0.1.22", - "examples/hello-world-server-streamable-http": "0.1.31", - "examples/simple-mcp-client-core-sse": "0.1.22", - "examples/simple-mcp-client-sse": "0.1.22" + "crates/rust-mcp-transport": "0.5.0", + "examples/hello-world-mcp-server-stdio": "0.1.28", + "examples/hello-world-mcp-server-stdio-core": "0.1.19", + "examples/simple-mcp-client-stdio": "0.1.28", + "examples/simple-mcp-client-stdio-core": "0.1.28", + "examples/hello-world-server-streamable-http-core": "0.1.19", + "examples/hello-world-server-streamable-http": "0.1.28", + "examples/simple-mcp-client-sse-core": "0.1.19", + "examples/simple-mcp-client-sse": "0.1.19", + "examples/simple-mcp-client-streamable-http": "0.1.0", + "examples/simple-mcp-client-streamable-http-core": "0.1.0" } diff --git a/Cargo.lock b/Cargo.lock index c10e354..c3c4462 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,10 +257,11 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.34" +version = "1.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc" +checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -381,9 +382,9 @@ checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" [[package]] name = "deranged" -version = "0.4.0" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" dependencies = [ "powerfmt", ] @@ -451,6 +452,12 @@ dependencies = [ "instant", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" + [[package]] name = "fnv" version = "1.0.7" @@ -687,8 +694,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] -name = "hello-world-mcp-server" -version = "0.1.31" +name = "hello-world-mcp-server-stdio" +version = "0.1.28" dependencies = [ "async-trait", "futures", @@ -701,8 +708,8 @@ dependencies = [ ] [[package]] -name = "hello-world-mcp-server-core" -version = "0.1.22" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -713,8 +720,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-core-streamable-http" -version = "0.1.22" +name = "hello-world-server-streamable-http" +version = "0.1.31" dependencies = [ "async-trait", "futures", @@ -727,8 +734,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-streamable-http" -version = "0.1.31" +name = "hello-world-server-streamable-http-core" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -1684,6 +1691,7 @@ dependencies = [ "async-trait", "axum", "axum-server", + "base64 0.22.1", "futures", "hyper 1.7.0", "reqwest", @@ -1698,6 +1706,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "wiremock", ] [[package]] @@ -1903,8 +1912,8 @@ dependencies = [ ] [[package]] -name = "simple-mcp-client" -version = "0.1.31" +name = "simple-mcp-client-sse" +version = "0.1.22" dependencies = [ "async-trait", "colored", @@ -1914,11 +1923,13 @@ dependencies = [ "serde_json", "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-core" -version = "0.1.31" +name = "simple-mcp-client-sse-core" +version = "0.1.19" dependencies = [ "async-trait", "colored", @@ -1928,11 +1939,41 @@ dependencies = [ "serde_json", "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-core-sse" -version = "0.1.22" +name = "simple-mcp-client-stdio" +version = "0.1.28" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + +[[package]] +name = "simple-mcp-client-stdio-core" +version = "0.1.28" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + +[[package]] +name = "simple-mcp-client-streamable-http" +version = "0.1.0" dependencies = [ "async-trait", "colored", @@ -1947,8 +1988,8 @@ dependencies = [ ] [[package]] -name = "simple-mcp-client-sse" -version = "0.1.22" +name = "simple-mcp-client-streamable-http-core" +version = "0.1.0" dependencies = [ "async-trait", "colored", @@ -2088,12 +2129,11 @@ dependencies = [ [[package]] name = "time" -version = "0.3.41" +version = "0.3.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "8ca967379f9d8eb8058d86ed467d81d03e81acd45757e4ca341c24affbe8e8e3" dependencies = [ "deranged", - "itoa", "num-conv", "powerfmt", "serde", @@ -2103,15 +2143,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +checksum = "a9108bb380861b07264b950ded55a44a14a4adc68b9f5efd85aafc3aa4d40a68" [[package]] name = "time-macros" -version = "0.2.22" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +checksum = "7182799245a7264ce590b349d90338f1c1affad93d2639aed5f8f69c090b334c" dependencies = [ "num-conv", "time-core", diff --git a/Cargo.toml b/Cargo.toml index b4f7cca..711204d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,14 +4,17 @@ members = [ "crates/rust-mcp-macros", "crates/rust-mcp-sdk", "crates/rust-mcp-transport", - "examples/simple-mcp-client", - "examples/simple-mcp-client-core", - "examples/hello-world-mcp-server", - "examples/hello-world-mcp-server-core", + "examples/simple-mcp-client-stdio", + "examples/simple-mcp-client-stdio-core", + "examples/hello-world-mcp-server-stdio", + "examples/hello-world-mcp-server-stdio-core", "examples/hello-world-server-streamable-http", - "examples/hello-world-server-core-streamable-http", + "examples/hello-world-server-streamable-http-core", "examples/simple-mcp-client-sse", - "examples/simple-mcp-client-core-sse", + "examples/simple-mcp-client-sse-core", + "examples/simple-mcp-client-streamable-http", + "examples/simple-mcp-client-streamable-http-core", + ] [workspace.dependencies] @@ -39,7 +42,7 @@ tracing-subscriber = { version = "0.3", features = [ "std", "fmt", ] } - +base64 = "0.22" axum = "0.8" rustls = "0.23" tokio-rustls = "0.26" diff --git a/README.md b/README.md index 1d334d6..c1e201c 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -32,13 +32,12 @@ This project supports following transports: πŸš€ The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - βœ… Streamable HTTP Support for MCP Servers - βœ… DNS Rebinding Protection - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients +- βœ… Streamable HTTP Support for MCP Clients - ⬜ Resumability - ⬜ Authentication / Oauth @@ -49,6 +48,7 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) @@ -110,7 +110,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -191,7 +191,8 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc, ) -> Result { + + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -205,7 +206,7 @@ impl ServerHandler for MyServerHandler { --- -πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -283,6 +284,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -294,8 +297,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +πŸ‘‰ see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -306,6 +383,8 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + ## Getting Started @@ -344,9 +423,15 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -356,12 +441,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -373,17 +452,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -395,6 +463,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -416,9 +495,13 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport. + - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features @@ -450,9 +533,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -465,7 +548,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` @@ -478,10 +561,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -510,7 +593,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 50d5e7f..99d6f86 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -24,15 +24,17 @@ futures = { workspace = true } thiserror = { workspace = true } axum = { workspace = true, optional = true } -uuid = { workspace = true, features = ["v4"], optional = true } +uuid = { workspace = true, features = ["v4"] } tokio-stream = { workspace = true, optional = true } axum-server = { version = "0.7", features = [], optional = true } tracing.workspace = true +base64.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } [dev-dependencies] +wiremock = "0.5" reqwest = { workspace = true, default-features = false, features = [ "stream", "rustls-tls", @@ -51,48 +53,55 @@ default = [ "client", "server", "macros", + "stdio", + "sse", + "streamable-http", "hyper-server", "ssl", "2025_06_18", ] # All features enabled by default -server = ["rust-mcp-transport/stdio"] # Server feature -client = ["rust-mcp-transport/stdio", "rust-mcp-transport/sse"] # Client feature -hyper-server = [ - "axum", - "axum-server", - "hyper", - "server", - "uuid", - "tokio-stream", - "rust-mcp-transport/sse", -] + +sse = ["rust-mcp-transport/sse"] +streamable-http = ["rust-mcp-transport/streamable-http"] +stdio = ["rust-mcp-transport/stdio"] + +server = [] # Server feature +client = [] # Client feature +hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream"] ssl = ["axum-server/tls-rustls"] tls-no-provider = ["axum-server/tls-rustls-no-provider"] macros = ["rust-mcp-macros/sdk"] -# enables mcp protocol version 2025_06_18 -2025_06_18 = [ +# enables mcp protocol version 2025-06-18 +2025-06-18 = [ "rust-mcp-schema/2025_06_18", "rust-mcp-macros/2025_06_18", "rust-mcp-transport/2025_06_18", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_06_18 = ["2025-06-18"] # enables mcp protocol version 2025_03_26 -2025_03_26 = [ +2025-03-26 = [ "rust-mcp-schema/2025_03_26", "rust-mcp-macros/2025_03_26", "rust-mcp-transport/2025_03_26", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_03_26 = ["2025-03-26"] + # enables mcp protocol version 2024_11_05 -2024_11_05 = [ +2024-11-05 = [ "rust-mcp-schema/2024_11_05", "rust-mcp-macros/2024_11_05", "rust-mcp-transport/2024_11_05", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2024_11_05 = ["2024-11-05"] [lints] workspace = true diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index cbe7318..8036022 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -32,13 +32,12 @@ This project supports following transports: πŸš€ The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - βœ… Streamable HTTP Support for MCP Servers - βœ… DNS Rebinding Protection - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients +- βœ… Streamable HTTP Support for MCP Clients - ⬜ Resumability - ⬜ Authentication / Oauth @@ -49,6 +48,7 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) @@ -110,7 +110,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -205,7 +205,7 @@ impl ServerHandler for MyServerHandler { --- -πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -283,6 +283,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -294,8 +296,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +πŸ‘‰ see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -306,6 +382,8 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + ## Getting Started @@ -344,9 +422,15 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -356,12 +440,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -373,17 +451,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -395,6 +462,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -416,9 +494,14 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport. + +- `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features @@ -449,9 +532,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -464,7 +547,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` @@ -477,10 +560,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -509,7 +592,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 3de8d98..3879526 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -11,25 +11,36 @@ pub type SdkResult = core::result::Result; #[derive(Debug, Error)] pub enum McpSdkError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("{0}")] RpcError(#[from] RpcError), + #[error("{0}")] - IoError(#[from] std::io::Error), - #[error("{0}")] - TransportError(#[from] TransportError), - #[error("{0}")] - JoinError(#[from] JoinError), - #[error("{0}")] - AnyError(Box<(dyn std::error::Error + Send + Sync)>), - #[error("{0}")] - SdkError(#[from] crate::schema::schema_utils::SdkError), + Join(#[from] JoinError), + #[cfg(feature = "hyper-server")] #[error("{0}")] - TransportServerError(#[from] TransportServerError), - #[error("Incompatible mcp protocol version: requested:{0} current:{1}")] - IncompatibleProtocolVersion(String, String), + HyperServer(#[from] TransportServerError), + #[error("{0}")] - ParseProtocolVersionError(#[from] ParseProtocolVersionError), + SdkError(#[from] crate::schema::schema_utils::SdkError), + + #[error("Protocol error: {kind}")] + Protocol { kind: ProtocolErrorKind }, +} + +// Sub-enum for protocol-related errors +#[derive(Debug, Error)] +pub enum ProtocolErrorKind { + #[error("Incompatible protocol version: requested {requested}, current {current}")] + IncompatibleVersion { requested: String, current: String }, + #[error("Failed to parse protocol version: {0}")] + ParseError(#[from] ParseProtocolVersionError), } impl McpSdkError { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index 0c1dcf3..ff6d5b2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -1,11 +1,9 @@ use std::{sync::Arc, time::Duration}; -use crate::schema::InitializeResult; -use rust_mcp_transport::TransportOptions; - +use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; - -use super::{session_store::SessionStore, IdGenerator}; +use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; +use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server /// @@ -14,7 +12,8 @@ use super::{session_store::SessionStore, IdGenerator}; #[derive(Clone)] pub struct AppState { pub session_store: Arc, - pub id_generator: Arc, + pub id_generator: Arc>, + pub stream_id_gen: Arc, pub server_details: Arc, pub handler: Arc, pub ping_interval: Duration, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index daf5d94..da69c67 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -6,7 +6,7 @@ use crate::{ }, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::mcp_handler::McpServerHandler, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, utils::validate_mcp_protocol_version, }; @@ -22,13 +22,12 @@ use axum::{ }; use futures::stream; use hyper::{header, HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, SseTransport}; +use rust_mcp_transport::{ + SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, +}; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; -pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; -pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; - const DUPLEX_BUFFER_SIZE: usize = 8192; async fn create_sse_stream( @@ -41,11 +40,11 @@ async fn create_sse_stream( let payload_string = payload.map(|p| p.to_string()); // TODO: this logic should be moved out after refactoing the mcp_stream.rs - let result = payload_string + let payload_contains_request = payload_string .as_ref() .map(|json_str| contains_request(json_str)) .unwrap_or(Ok(false)); - let Ok(payload_contains_request) = result else { + let Ok(payload_contains_request) = payload_contains_request else { return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); }; @@ -54,18 +53,20 @@ async fn create_sse_stream( // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - let transport = SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + let transport = Arc::new( + SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?, + ); - let stream_id = if standalone { + let stream_id: StreamId = if standalone { DEFAULT_STREAM_ID.to_string() } else { - state.id_generator.generate() + state.stream_id_gen.generate() }; let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); @@ -85,6 +86,7 @@ async fn create_sse_stream( // Construct SSE stream let reader = BufReader::new(write_rx); + // outgoing messages from server to the client let message_stream = stream::unfold(reader, |mut reader| async move { let mut line = String::new(); @@ -117,12 +119,12 @@ async fn create_sse_stream( // TODO: this function will be removed after refactoring the readable stream of the transports // so we would deserialize the string syncronousely and have more control over the flow -// this function could potentially add a 20-250 ns overhead which could be avoided +// this function may incur a slight runtime cost which could be avoided after refactoring fn contains_request(json_str: &str) -> Result { let value: serde_json::Value = serde_json::from_str(json_str)?; match value { serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")), - serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| { item.as_object() .map(|obj| obj.contains_key("id") && obj.contains_key("method")) .unwrap_or(false) @@ -131,6 +133,19 @@ fn contains_request(json_str: &str) -> Result { } } +fn is_result(json_str: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(json_str)?; + match value { + serde_json::Value::Object(obj) => Ok(obj.contains_key("result")), + serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + item.as_object() + .map(|obj| obj.contains_key("result")) + .unwrap_or(false) + })), + _ => Ok(false), + } +} + pub async fn create_standalone_stream( session_id: SessionId, state: Arc, @@ -224,7 +239,12 @@ async fn single_shot_stream( tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream( + Arc::new(transport), + &stream_id, + ping_interval, + payload_string, + ) .await { Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), @@ -233,7 +253,6 @@ async fn single_shot_stream( let _ = runtime.remove_transport(&stream_id).await; }); - // Construct SSE stream let mut reader = BufReader::new(write_rx); let mut line = String::new(); let response = match reader.read_line(&mut line).await { @@ -310,15 +329,34 @@ pub async fn process_incoming_message( match state.session_store.get(&session_id).await { Some(runtime) => { let runtime = runtime.lock().await.to_owned(); - - create_sse_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - ) - .await + // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport + // it should be processed by the same transport , therefore no need to call create_sse_stream + let Ok(is_result) = is_result(payload) else { + return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); + }; + + if is_result { + match runtime + .consume_payload_string(DEFAULT_STREAM_ID, payload) + .await + { + Ok(()) => Ok((StatusCode::ACCEPTED, Json(())).into_response()), + Err(err) => Ok(( + StatusCode::BAD_REQUEST, + Json(SdkError::internal_error().with_message(err.to_string().as_ref())), + ) + .into_response()), + } + } else { + create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + ) + .await + } } None => { let error = SdkError::session_not_found(); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index a014e94..27a16b2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,3 +1,4 @@ +use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ hyper_servers::{ @@ -90,13 +91,17 @@ pub async fn handle_sse( let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); // create a transport for sending/receiving messages - let transport = SseTransport::new( + let Ok(transport) = SseTransport::new( read_rx, write_tx, read_tx, Arc::clone(&state.transport_options), - ) - .unwrap(); + ) else { + return Err(TransportServerError::TransportError( + "Failed to create SSE transport".to_string(), + )); + }; + let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and let server: Arc = server_runtime::create_server_instance( @@ -115,7 +120,12 @@ pub async fn handle_sse( // Start the server tokio::spawn(async move { match server - .start_stream(transport, DEFAULT_STREAM_ID, state.ping_interval, None) + .start_stream( + Arc::new(transport), + DEFAULT_STREAM_ID, + state.ping_interval, + None, + ) .await { Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 83cc372..00d46c0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,4 +1,4 @@ -use super::hyper_utils::{start_new_session, MCP_SESSION_ID_HEADER}; +use super::hyper_utils::start_new_session; use crate::schema::schema_utils::SdkError; use crate::{ error::McpSdkError, @@ -14,6 +14,7 @@ use crate::{ }, utils::valid_initialize_method, }; +use axum::routing::get; use axum::{ extract::{Query, State}, middleware, @@ -22,11 +23,9 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::SessionId; +use rust_mcp_transport::{SessionId, MCP_SESSION_ID_HEADER}; use std::{collections::HashMap, sync::Arc}; -use axum::routing::get; - pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { Router::new() .route(streamable_http_endpoint, get(handle_streamable_http_get)) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index f093da3..1c3b3cf 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,6 +1,8 @@ use crate::{ - error::SdkResult, mcp_server::hyper_runtime::HyperRuntime, - mcp_traits::mcp_handler::McpServerHandler, + error::SdkResult, + id_generator::{FastIdGenerator, UuidGenerator}, + mcp_server::hyper_runtime::HyperRuntime, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; #[cfg(feature = "ssl")] use axum_server::tls_rustls::RustlsConfig; @@ -17,11 +19,11 @@ use super::{ app_state::AppState, error::{TransportServerError, TransportServerResult}, routes::app_routes, - IdGenerator, InMemorySessionStore, UuidGenerator, + InMemorySessionStore, }; use crate::schema::InitializeResult; use axum::Router; -use rust_mcp_transport::TransportOptions; +use rust_mcp_transport::{SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); @@ -43,7 +45,7 @@ pub struct HyperServerOptions { pub port: u16, /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, + pub session_id_generator: Option>>, /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, @@ -258,6 +260,7 @@ impl HyperServer { .session_id_generator .take() .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)), + stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))), server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs index 95b2158..4384b1a 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs @@ -5,7 +5,6 @@ use async_trait::async_trait; pub use in_memory::*; use rust_mcp_transport::SessionId; use tokio::sync::Mutex; -use uuid::Uuid; use crate::mcp_server::ServerRuntime; @@ -46,26 +45,3 @@ pub trait SessionStore: Send + Sync { async fn has(&self, session: &SessionId) -> bool; } - -/// Trait for generating session identifiers -/// -/// Implementors must be Send and Sync to support concurrent access. -pub trait IdGenerator: Send + Sync { - fn generate(&self) -> SessionId; -} - -/// Struct implementing the IdGenerator trait using UUID v4 -/// -/// This is a simple wrapper around the uuid crate's Uuid::new_v4 function -/// to generate unique session identifiers. -pub struct UuidGenerator {} - -impl IdGenerator for UuidGenerator { - /// Generates a new UUID v4-based session identifier - /// - /// # Returns - /// * `SessionId` - A new UUID-based session identifier as a String - fn generate(&self) -> SessionId { - Uuid::new_v4().to_string() - } -} diff --git a/crates/rust-mcp-sdk/src/id_generator.rs b/crates/rust-mcp-sdk/src/id_generator.rs new file mode 100644 index 0000000..54f0e72 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator.rs @@ -0,0 +1,5 @@ +mod fast_id_generator; +mod uuid_generator; +pub use crate::mcp_traits::IdGenerator; +pub use fast_id_generator::*; +pub use uuid_generator::*; diff --git a/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs new file mode 100644 index 0000000..fc2e976 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs @@ -0,0 +1,53 @@ +use crate::mcp_traits::IdGenerator; +use base64::Engine; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// An [`IdGenerator`] implementation optimized for lightweight, locally-scoped identifiers. +/// +/// This generator produces short, incrementing identifiers that are Base64-encoded. +/// This makes it well-suited for cases such as `StreamId` generation, where: +/// - IDs only need to be unique within a single process or session +/// - Predictability is acceptable +/// - Shorter, more human-readable identifiers are desirable +/// +pub struct FastIdGenerator { + counter: AtomicU64, + ///Optional prefix for readability + prefix: &'static str, +} + +impl FastIdGenerator { + /// Creates a new ID generator with an optional prefix. + /// + /// # Arguments + /// * `prefix` - A static string to prepend to IDs (e.g., "sid_"). + pub fn new(prefix: Option<&'static str>) -> Self { + FastIdGenerator { + counter: AtomicU64::new(0), + prefix: prefix.unwrap_or_default(), + } + } +} + +impl IdGenerator for FastIdGenerator +where + T: From, +{ + /// Generates a new session ID as a short Base64-encoded string. + /// + /// Increments an internal counter atomically and encodes it in Base64 URL-safe format. + /// The resulting ID is prefixed (if provided) and typically 8–12 characters long. + /// + /// # Returns + /// * `SessionId` - A short, unique session ID (e.g., "sid_BBBB" or "BBBB"). + fn generate(&self) -> T { + let id = self.counter.fetch_add(1, Ordering::Relaxed); + let bytes = id.to_le_bytes(); + let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + if self.prefix.is_empty() { + T::from(encoded) + } else { + T::from(format!("{}{}", self.prefix, encoded)) + } + } +} diff --git a/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs new file mode 100644 index 0000000..2f0dc21 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs @@ -0,0 +1,18 @@ +use crate::mcp_traits::IdGenerator; +use uuid::Uuid; + +/// An [`IdGenerator`] implementation that uses UUID v4 to create unique identifiers. +/// +/// This generator produces random UUIDs (version 4), which are highly unlikely +/// to collide and difficult to predict. It is therefore well-suited for +/// generating identifiers such as `SessionId` or other values where uniqueness is important. +pub struct UuidGenerator; + +impl IdGenerator for UuidGenerator +where + T: From, +{ + fn generate(&self) -> T { + T::from(Uuid::new_v4().to_string()) + } +} diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 1ea23df..a33f889 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -21,7 +21,7 @@ pub mod mcp_client { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) for an example. + //! Refer to [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) for an example. //! //! //! - **client_runtime_core**: If you need more control over MCP messages, consider using @@ -30,7 +30,7 @@ pub mod mcp_client { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) for an example. + //! Refer to [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) for an example. pub use super::mcp_handlers::mcp_client_handler::ClientHandler; pub use super::mcp_handlers::mcp_client_handler_core::ClientHandlerCore; pub use super::mcp_runtimes::client_runtime::mcp_client_runtime as client_runtime; @@ -53,7 +53,7 @@ pub mod mcp_server { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) for an example. //! //! //! - **server_runtime_core**: If you need more control over MCP messages, consider using @@ -62,7 +62,7 @@ pub mod mcp_server { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/hello-world-mcp-server-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) for an example. pub use super::mcp_handlers::mcp_server_handler::ServerHandler; pub use super::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; @@ -93,4 +93,5 @@ pub mod macros { pub use rust_mcp_macros::*; } +pub mod id_generator; pub mod schema; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 7ee0815..2093dc3 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,12 +1,17 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; - +use crate::error::{McpSdkError, SdkResult}; +use crate::id_generator::FastIdGenerator; +use crate::mcp_traits::mcp_client::McpClient; +use crate::mcp_traits::mcp_handler::McpClientHandler; +use crate::mcp_traits::IdGenerator; +use crate::utils::ensure_server_protocole_compatibility; use crate::{ mcp_traits::{RequestIdGen, RequestIdGenNumeric}, schema::{ schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, + self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, + ServerMessage, ServerMessages, }, InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, RequestId, RpcError, ServerResult, @@ -16,63 +21,100 @@ use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; -use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::{ - sync::{Arc, RwLock}, - time::Duration, -}; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions}; +use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::sync::Mutex; +use tokio::sync::{watch, Mutex}; -use crate::error::{McpSdkError, SdkResult}; -use crate::mcp_traits::mcp_client::McpClient; -use crate::mcp_traits::mcp_handler::McpClientHandler; -use crate::utils::ensure_server_protocole_compatibility; +pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; + +// Define a type alias for the TransportDispatcher trait object +type TransportDispatcherType = dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, +>; +type TransportType = Arc; pub struct ClientRuntime { - // The transport interface for handling messages between client and server - transport: Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + // A thread-safe map storing transport types + transport_map: tokio::sync::RwLock>, // The handler for processing MCP messages handler: Box, - // // Information about the server + // Information about the server client_details: InitializeRequestParams, - // Details about the connected server - server_details: Arc>>, handlers: Mutex>>>, + // Generator for unique request IDs request_id_gen: Box, + // Generator for stream IDs + stream_id_gen: FastIdGenerator, + #[cfg(feature = "streamable-http")] + // Optional configuration for streamable transport + transport_options: Option, + // Flag indicating whether the client has been shut down + is_shut_down: Mutex, + // Session ID + session_id: tokio::sync::RwLock>, + // Details about the connected server + server_details_tx: watch::Sender>, + server_details_rx: watch::Receiver>, } impl ClientRuntime { pub(crate) fn new( client_details: InitializeRequestParams, - transport: impl Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, + transport: TransportType, handler: Box, ) -> Self { + let mut map: HashMap = HashMap::new(); + map.insert(DEFAULT_STREAM_ID.to_string(), transport); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); Self { - transport: Arc::new(transport), + transport_map: tokio::sync::RwLock::new(map), handler, client_details, - server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + #[cfg(feature = "streamable-http")] + transport_options: None, + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, } } - async fn initialize_request(&self) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + pub(crate) fn new_instance( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: Box, + ) -> Self { + let map: HashMap = HashMap::new(); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); + Self { + transport_map: tokio::sync::RwLock::new(map), + handler, + client_details, + handlers: Mutex::new(vec![]), + transport_options: Some(transport_options), + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, + } + } + + async fn initialize_request(self: Arc) -> SdkResult<()> { let request = InitializeRequest::new(self.client_details.clone()); let result: ServerResult = self.request(request.into(), None).await?.try_into()?; @@ -81,9 +123,15 @@ impl ClientRuntime { &self.client_details.protocol_version, &initialize_result.protocol_version, )?; - // store server details self.set_server_details(initialize_result)?; + + #[cfg(feature = "streamable-http")] + // try to create a sse stream for server initiated messages , if supported by the server + if let Err(error) = self.clone().create_sse_stream().await { + tracing::warn!("{error}"); + } + // send a InitializedNotification to the server self.send_notification(InitializedNotification::new(None).into()) .await?; @@ -92,21 +140,14 @@ impl ClientRuntime { .with_message("Incorrect response to InitializeRequest!".into()) .into()); } + Ok(()) } pub(crate) async fn handle_message( &self, message: ServerMessage, - transport: &Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + transport: &TransportType, ) -> SdkResult> { let response = match message { ServerMessage::Request(jsonrpc_request) => { @@ -162,28 +203,26 @@ impl ClientRuntime { }; Ok(response) } -} -#[async_trait] -impl McpClient for ClientRuntime { - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch, - { - (self.transport.message_sender().clone()) as _ - } + async fn start_standalone(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; - async fn start(self: Arc) -> SdkResult<()> { //TODO: improve the flow - let mut stream = self.transport.start().await?; - let transport = self.transport.clone(); + let mut stream = transport.start().await?; + + let transport_clone = transport.clone(); let mut error_io_stream = transport.error_stream().write().await; let error_io_stream = error_io_stream.take(); let self_clone = Arc::clone(&self); let self_clone_err = Arc::clone(&self); + // task reading from the error stream let err_task = tokio::spawn(async move { let self_ref = &*self_clone_err; @@ -191,7 +230,7 @@ impl McpClient for ClientRuntime { let mut reader = BufReader::new(error_input).lines(); loop { tokio::select! { - should_break = self_ref.transport.is_shut_down() =>{ + should_break = transport_clone.is_shut_down() =>{ if should_break { break; } @@ -221,14 +260,10 @@ impl McpClient for ClientRuntime { Ok::<(), McpSdkError>(()) }); - let transport = self.transport.clone(); + let transport = transport.clone(); + // main task reading from mcp_message stream let main_task = tokio::spawn(async move { - let sender = self_clone.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; while let Some(mcp_messages) = stream.next().await { let self_ref = &*self_clone; @@ -239,7 +274,7 @@ impl McpClient for ClientRuntime { match result { Ok(result) => { if let Some(result) = result { - sender + transport .send_message(ClientMessages::Single(result), None) .await?; } @@ -260,7 +295,7 @@ impl McpClient for ClientRuntime { let results: Vec<_> = results.into_iter().flatten().collect(); if !results.is_empty() { - sender + transport .send_message(ClientMessages::Batch(results), None) .await?; } @@ -271,71 +306,349 @@ impl McpClient for ClientRuntime { }); // send initialize request to the MCP server - self.initialize_request().await?; + self.clone().initialize_request().await?; let mut lock = self.handlers.lock().await; lock.push(main_task); lock.push(err_task); + Ok(()) + } + pub(crate) async fn store_transport( + &self, + stream_id: &str, + transport: TransportType, + ) -> SdkResult<()> { + let mut transport_map = self.transport_map.write().await; + tracing::trace!("save transport for stream id : {}", stream_id); + transport_map.insert(stream_id.to_string(), transport); Ok(()) } - fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { - match self.server_details.write() { - Ok(mut details) => { - *details = Some(server_details); - Ok(()) - } - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - Err(_) => Err(RpcError::internal_error() - .with_message("Internal Error: Failed to acquire write lock.".to_string()) - .into()), - } + pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult { + let transport_map = self.transport_map.read().await; + transport_map.get(stream_id).cloned().ok_or_else(|| { + RpcError::internal_error() + .with_message(format!("Transport for key {stream_id} not found")) + .into() + }) } - fn client_info(&self) -> &InitializeRequestParams { - &self.client_details + + #[cfg(feature = "streamable-http")] + pub(crate) async fn new_transport( + &self, + session_id: Option, + standalone: bool, + ) -> SdkResult< + impl TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > { + let options = self + .transport_options + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + let transport = ClientStreamableTransport::new(options, session_id, standalone)?; + + Ok(transport) } - fn server_info(&self) -> Option { - if let Ok(details) = self.server_details.read() { - details.clone() - } else { - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - None + + #[cfg(feature = "streamable-http")] + pub(crate) async fn create_sse_stream(self: Arc) -> SdkResult<()> { + let stream_id: StreamId = DEFAULT_STREAM_ID.into(); + let session_id = self.session_id.read().await.clone(); + let transport: Arc< + dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > = Arc::new(self.new_transport(session_id, true).await?); + let mut stream = transport.start().await?; + self.store_transport(&stream_id, transport.clone()).await?; + + let self_clone = Arc::clone(&self); + + let main_task = tokio::spawn(async move { + loop { + if let Some(mcp_messages) = stream.next().await { + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + + if let Some(result) = result { + transport + .send_message(ClientMessages::Single(result), None) + .await?; + } + } + ServerMessages::Batch(server_messages) => { + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| { + self.handle_message(server_message, &transport) + }) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport + .send_message(ClientMessages::Batch(results), None) + .await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID) { + return Ok::<_, McpSdkError>(()); + } + } else { + // end of stream + return Ok::<_, McpSdkError>(()); + } + } + }); + + let mut lock = self_clone.handlers.lock().await; + lock.push(main_task); + + Ok(()) + } + + #[cfg(feature = "streamable-http")] + pub(crate) async fn start_stream( + &self, + messages: ClientMessages, + timeout: Option, + ) -> SdkResult> { + use futures::stream::{AbortHandle, Abortable}; + let stream_id: StreamId = self.stream_id_gen.generate(); + let session_id = self.session_id.read().await.clone(); + let no_session_id = session_id.is_none(); + + let has_request = match &messages { + ClientMessages::Single(client_message) => client_message.is_request(), + ClientMessages::Batch(client_messages) => { + client_messages.iter().any(|m| m.is_request()) + } + }; + + let transport = Arc::new(self.new_transport(session_id, false).await?); + + let mut stream = transport.start().await?; + + self.store_transport(&stream_id, transport).await?; + + let transport = self.transport_by_stream(&stream_id).await?; //TODO: remove + + let send_task = async { + let result = transport.send_message(messages, timeout).await?; + + if no_session_id { + if let Some(request_id) = transport.session_id().await.clone() { + let mut guard = self.session_id.write().await; + *guard = Some(request_id) + } + } + + Ok::<_, McpSdkError>(result) + }; + + if !has_request { + return send_task.await; } + + let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair(); + + let receive_task = async { + loop { + tokio::select! { + Some(mcp_messages) = stream.next() =>{ + + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + if let Some(result) = result { + transport.send_message(ClientMessages::Single(result), None).await?; + } + } + ServerMessages::Batch(server_messages) => { + + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| self.handle_message(server_message, &transport)) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport.send_message(ClientMessages::Batch(results), None).await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID){ + return Ok::<_, McpSdkError>(()); + } + } + } + } + }; + + let receive_task = Abortable::new(receive_task, abort_recv_reg); + + // Pin the tasks to ensure they are not moved + tokio::pin!(send_task); + tokio::pin!(receive_task); + + // Run both tasks with cancellation logic + let (send_res, _) = tokio::select! { + res = &mut send_task => { + // cancel the receive_task task, to cover the case where send_task returns with error + abort_recv_handle.abort(); + (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation) + } + res = &mut receive_task => { + (send_task.await, res) + } + }; + send_res } +} +#[async_trait] +impl McpClient for ClientRuntime { async fn send( &self, message: MessageFromClient, request_id: Option, - timeout: Option, + request_timeout: Option, ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + + let response = self + .start_stream(ClientMessages::Single(mcp_message), request_timeout) + .await?; + return response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()); + } + } + + let transport_map = self.transport_map.read().await; + + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; let outgoing_request_id = self .request_id_gen .request_id_for_message(&message, request_id); let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + let response = transport + .send_message(ClientMessages::Single(mcp_message), request_timeout) + .await?; + response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()) + } - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) - .await? - .map(|res| res.as_single()) - .transpose()?; + async fn send_batch( + &self, + messages: Vec, + timeout: Option, + ) -> SdkResult>> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let result = self + .start_stream(ClientMessages::Batch(messages), timeout) + .await?; + // let response = self.start_stream(&stream_id, request_id, message).await?; + return result + .map(|r| r.as_batch()) + .transpose() + .map_err(|err| err.into()); + } + } - Ok(response) + let transport_map = self.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + transport + .send_batch(messages, timeout) + .await + .map_err(|err| err.into()) + } + + async fn start(self: Arc) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + self.initialize_request().await?; + return Ok(()); + } + } + + self.start_standalone().await + } + + fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { + self.server_details_tx + .send(Some(server_details)) + .map_err(|_| { + RpcError::internal_error() + .with_message("Failed to set server details".to_string()) + .into() + }) + } + + fn client_info(&self) -> &InitializeRequestParams { + &self.client_details + } + + fn server_info(&self) -> Option { + self.server_details_rx.borrow().clone() } async fn is_shut_down(&self) -> bool { - self.transport.is_shut_down().await + let result = self.is_shut_down.lock().await; + *result } + async fn shut_down(&self) -> SdkResult<()> { - self.transport.shut_down().await?; + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + let mut transport_map = self.transport_map.write().await; + let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); + drop(transport_map); + for transport in transports { + let _ = transport.shut_down().await; + } // wait for tasks let mut tasks_lock = self.handlers.lock().await; @@ -344,4 +657,18 @@ impl McpClient for ClientRuntime { Ok(()) } + + async fn terminate_session(&self) { + #[cfg(feature = "streamable-http")] + { + if let Some(transport_options) = self.transport_options.as_ref() { + let session_id = self.session_id.read().await.clone(); + transport_options + .terminate_session(session_id.as_ref()) + .await; + let _ = self.shut_down().await; + } + } + let _ = self.shut_down().await; + } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 7925f07..43a7079 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -8,7 +8,10 @@ use crate::schema::{ InitializeRequestParams, RpcError, ServerNotification, ServerRequest, }; use async_trait::async_trait; -use rust_mcp_transport::Transport; + +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; use crate::{ error::SdkResult, mcp_client::ClientHandler, mcp_traits::mcp_handler::McpClientHandler, @@ -37,10 +40,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -51,7 +54,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandler, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientInternalHandler::new(Box::new(handler))), )) } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 8cb8cff..884de9d 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -1,5 +1,4 @@ -use std::sync::Arc; - +use super::ClientRuntime; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, MessageFromClient, NotificationFromServer, @@ -7,17 +6,16 @@ use crate::schema::{ }, InitializeRequestParams, RpcError, }; -use async_trait::async_trait; - -use rust_mcp_transport::Transport; - use crate::{ error::SdkResult, mcp_handlers::mcp_client_handler_core::ClientHandlerCore, mcp_traits::{mcp_client::McpClient, mcp_handler::McpClientHandler}, }; - -use super::ClientRuntime; +use async_trait::async_trait; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; +use std::sync::Arc; /// Creates a new MCP client runtime with the specified configuration. /// @@ -39,10 +37,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -53,7 +51,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientCoreInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandlerCore, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientCoreInternalHandler::new(Box::new(handler))), )) } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 57ba260..1b24b57 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -19,9 +19,11 @@ use futures::{StreamExt, TryFutureExt}; use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; use std::collections::HashMap; +use std::panic; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; + use tokio::sync::{mpsc, oneshot, watch}; pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; @@ -46,7 +48,7 @@ pub struct ServerRuntime { server_details: Arc, #[cfg(feature = "hyper-server")] session_id: Option, - transport_map: tokio::sync::RwLock>, + transport_map: tokio::sync::RwLock>, //TODO: remove the transport_map, we do not need a hashmap for it request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, @@ -357,6 +359,9 @@ impl ServerRuntime { >, >, ) -> SdkResult<()> { + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } let mut transport_map = self.transport_map.write().await; tracing::trace!("save transport for stream id : {}", stream_id); transport_map.insert(stream_id.to_string(), transport); @@ -364,34 +369,18 @@ impl ServerRuntime { } pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } let mut transport_map = self.transport_map.write().await; tracing::trace!("removing transport for stream id : {}", stream_id); + if let Some(transport) = transport_map.get(stream_id) { + transport.shut_down().await?; + } transport_map.remove(stream_id); Ok(()) } - pub(crate) async fn transport_by_stream( - &self, - stream_id: &str, - ) -> SdkResult< - Arc< - dyn TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, - >, - >, - > { - let transport_map = self.transport_map.read().await; - transport_map.get(stream_id).cloned().ok_or_else(|| { - RpcError::internal_error() - .with_message(format!("Transport for key {stream_id} not found")) - .into() - }) - } - pub(crate) async fn shutdown(&self) { let mut transport_map = self.transport_map.write().await; let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); @@ -403,17 +392,24 @@ impl ServerRuntime { pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool { let transport_map = self.transport_map.read().await; - transport_map.contains_key(stream_id) + let live_transport = if let Some(t) = transport_map.get(stream_id) { + !t.is_shut_down().await + } else { + false + }; + live_transport } pub(crate) async fn start_stream( self: Arc, - transport: impl TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, + transport: Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, >, stream_id: &str, ping_interval: Duration, @@ -421,10 +417,11 @@ impl ServerRuntime { ) -> SdkResult<()> { let mut stream = transport.start().await?; - self.store_transport(stream_id, Arc::new(transport)).await?; + if stream_id == DEFAULT_STREAM_ID { + self.store_transport(stream_id, transport.clone()).await?; + } let self_clone = self.clone(); - let transport = self_clone.transport_by_stream(stream_id).await?; let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); let abort_alive_task = transport @@ -439,7 +436,10 @@ impl ServerRuntime { // in case there is a payload, we consume it by transport to get processed if let Some(payload) = payload { - transport.consume_string_payload(&payload).await?; + if let Err(err) = transport.consume_string_payload(&payload).await { + let _ = self.remove_transport(stream_id).await; + return Err(err.into()); + } } // Create a channel to collect results from spawned tasks diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index 5fbc43c..62fd31f 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -38,7 +38,7 @@ use crate::{ /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index 5ed2239..110b20b 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -32,7 +32,7 @@ use std::sync::Arc; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 2b155fa..b66ba93 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -1,3 +1,4 @@ +pub(super) mod id_generator; #[cfg(feature = "client")] pub mod mcp_client; pub mod mcp_handler; @@ -5,4 +6,5 @@ pub mod mcp_handler; pub mod mcp_server; mod request_id_gen; +pub use id_generator::*; pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs new file mode 100644 index 0000000..e7cb8d3 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs @@ -0,0 +1,12 @@ +/// Trait for generating unique identifiers. +/// +/// This trait is generic over the target ID type, allowing it to be used for +/// generating different kinds of identifiers such as `SessionId` or +/// transport-scoped `StreamId`. +/// +pub trait IdGenerator: Send + Sync +where + T: From, +{ + fn generate(&self) -> T; +} diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 1883581..5fe3fba 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -1,9 +1,7 @@ -use std::{sync::Arc, time::Duration}; - use crate::schema::{ schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, - NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages, + ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient, + ResultFromServer, ServerMessage, }, CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams, CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation, @@ -17,21 +15,18 @@ use crate::schema::{ }; use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; -use rust_mcp_transport::{McpDispatch, MessageDispatcher}; +use std::{sync::Arc, time::Duration}; #[async_trait] pub trait McpClient: Sync + Send { async fn start(self: Arc) -> SdkResult<()>; fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>; + async fn terminate_session(&self); + async fn shut_down(&self) -> SdkResult<()>; async fn is_shut_down(&self) -> bool; - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch; - fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; @@ -170,48 +165,20 @@ pub trait McpClient: Sync + Send { &self, message: MessageFromClient, request_id: Option, - timeout: Option, + request_timeout: Option, ) -> SdkResult>; async fn send_batch( &self, messages: Vec, timeout: Option, - ) -> SdkResult>> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Batch(messages), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_batch()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + ) -> SdkResult>>; /// Sends a notification. This is a one-way message that is not expected /// to return any response. The method asynchronously sends the notification using /// the transport layer and does not wait for any acknowledgement or result. async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let mcp_message = ClientMessage::from_message(MessageFromClient::from(notification), None)?; - - sender - .send_message(ClientMessages::Single(mcp_message), None) - .await?; + self.send(notification.into(), None, None).await?; Ok(()) } diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index e98a1ed..16fe7c7 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -1,6 +1,6 @@ use crate::schema::schema_utils::{ClientMessages, SdkError}; -use crate::error::{McpSdkError, SdkResult}; +use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; @@ -71,20 +71,20 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st /// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05"); /// assert!(result.is_ok()); /// -/// // Incompatible versions (client < server) +/// // Incompatible versions (requested < current) /// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2024_11_05" && server == "2025_03_26" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2024_11_05" && current == "2025_03_26" /// )); /// -/// // Incompatible versions (client > server) +/// // Incompatible versions (requested > current) /// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -93,10 +93,12 @@ pub fn ensure_server_protocole_compatibility( server_protocol_version: &str, ) -> SdkResult<()> { match client_protocol_version.cmp(server_protocol_version) { - Ordering::Less | Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(()), } } @@ -140,8 +142,8 @@ pub fn ensure_server_protocole_compatibility( /// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -151,10 +153,12 @@ pub fn enforce_compatible_protocol_version( ) -> SdkResult> { match client_protocol_version.cmp(server_protocol_version) { // if client protocol version is higher - Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(None), Ordering::Less => { // return the same version that was received from the client @@ -164,7 +168,10 @@ pub fn enforce_compatible_protocol_version( } pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> { - let _mcp_protocol_version = ProtocolVersion::try_from(mcp_protocol_version)?; + let _mcp_protocol_version = + ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol { + kind: ProtocolErrorKind::ParseError(err), + })?; Ok(()) } diff --git a/crates/rust-mcp-sdk/tests/check_imports.rs b/crates/rust-mcp-sdk/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-sdk/tests/check_imports.rs +++ b/crates/rust-mcp-sdk/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 564db0d..f330dda 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -1,5 +1,8 @@ +mod mock_server; +mod test_client; mod test_server; use async_trait::async_trait; +pub use mock_server::*; use reqwest::{Client, Response, Url}; use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::ProtocolVersion; @@ -8,9 +11,12 @@ use rust_mcp_sdk::mcp_client::ClientHandler; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::time::timeout; use tokio_stream::StreamExt; +use wiremock::{MockServer, Request, ResponseTemplate}; +pub use test_client::*; pub use test_server::*; pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything"; @@ -337,3 +343,52 @@ pub mod sample_tools { } } } + +pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request { + let requests = mock_server.received_requests().await.unwrap(); + requests[index].clone() +} + +pub async fn debug_wiremock(mock_server: &MockServer) { + let requests = mock_server.received_requests().await.unwrap(); + let len = requests.len(); + println!(">>> {len} request(s) received <<<"); + + for (index, request) in requests.iter().enumerate() { + println!("\n--- #{index} of {len} ---"); + println!("Method: {}", request.method); + println!("Path: {}", request.url.path()); + // println!("Headers: {:#?}", request.headers); + println!("---- headers ----"); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + let body_str = String::from_utf8_lossy(&request.body); + println!("Body: {body_str}\n"); + } +} + +pub fn create_sse_response(payload: &str) -> ResponseTemplate { + let sse_body = format!(r#"data: {}{}"#, payload, "\n\n"); + ResponseTemplate::new(200).set_body_raw(sse_body.into_bytes(), "text/event-stream") +} + +pub async fn wait_for_n_requests( + mock_server: &MockServer, + num_requests: usize, + duration: Option, +) { + let duration = duration.unwrap_or(Duration::from_secs(1)); + timeout(duration, async { + loop { + let requests = mock_server.received_requests().await.unwrap(); + if requests.len() >= num_requests { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }) + .await + .unwrap(); +} diff --git a/crates/rust-mcp-sdk/tests/common/mock_server.rs b/crates/rust-mcp-sdk/tests/common/mock_server.rs new file mode 100644 index 0000000..f5b533a --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/mock_server.rs @@ -0,0 +1,528 @@ +use axum::{ + body::Body, + extract::Request, + http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, Method, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Response, Sse, + }, + routing::any, + Router, +}; +use core::fmt; +use futures::stream; +use std::collections::VecDeque; +use std::{future::Future, net::SocketAddr, pin::Pin}; +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; +use tokio::net::TcpListener; + +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl ToString for SseEvent { + fn to_string(&self) -> String { + let mut s = String::new(); + + if let Some(id) = &self.id { + s.push_str("id: "); + s.push_str(id); + s.push('\n'); + } + + if let Some(event) = &self.event { + s.push_str("event: "); + s.push_str(event); + s.push('\n'); + } + + if let Some(data) = &self.data { + // Convert bytes to string safely, fallback if invalid UTF-8 + for line in data.lines() { + s.push_str("data: "); + s.push_str(line); + s.push('\n'); + } + } + + s.push('\n'); // End of event + s + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self.data.as_ref(); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +// RequestRecord stores the history of incoming requests +#[derive(Clone, Debug)] +pub struct RequestRecord { + pub method: Method, + pub path: String, + pub headers: HeaderMap, + pub body: String, +} + +#[derive(Clone, Debug)] +pub struct ResponseRecord { + pub status: StatusCode, + pub headers: HeaderMap, + pub body: String, +} + +// pub type BoxedStream = +// Pin> + Send>>; +// pub type BoxedSseResponse = Sse; + +// pub type AsyncResponseFn = +// Box Pin + Send>> + Send + Sync>; + +type AsyncResponseFn = + Box Pin + Send>> + Send + Sync>; + +// Mock defines a single mock response configuration +// #[derive(Clone)] +pub struct Mock { + method: Method, + path: String, + response: String, + response_func: Option, + header_map: HeaderMap, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +// MockBuilder is a factory for creating Mock instances +pub struct MockBuilder { + method: Method, + path: String, + response: String, + header_map: HeaderMap, + response_func: Option, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +impl MockBuilder { + fn new(method: Method, path: String, response: String, header_map: HeaderMap) -> Self { + Self { + method, + path, + response, + response_func: None, + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + fn new_with_func( + method: Method, + path: String, + response_func: AsyncResponseFn, + header_map: HeaderMap, + ) -> Self { + Self { + method, + path, + response: String::new(), + response_func: Some(response_func), + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + pub fn new_breakable_sse( + method: Method, + path: String, + repeating_message: SseEvent, + interval: Duration, + repeat: usize, + ) -> Self { + let message = Arc::new(repeating_message); + let interval = interval; + let max_repeats = repeat; + + let response_fn: AsyncResponseFn = Box::new({ + let message = Arc::clone(&message); + move || { + let message = Arc::clone(&message); + + Box::pin(async move { + // Construct SSE stream with 10 static messages using unfold + let message_stream = stream::unfold(0, move |count| { + let message = Arc::clone(&message); + + async move { + if count >= max_repeats { + return Some(( + Err(std::io::Error::other("Message limit reached")), + count, + )); + } + tokio::time::sleep(interval).await; + + Some(( + Ok(Event::default() + .data(message.data.clone().unwrap_or("".into())) + .id(message.id.clone().unwrap_or(format!("msg-id_{count}"))) + .event(message.event.clone().unwrap_or("message".into()))), + count + 1, + )) + } + }); + + let sse_stream = Sse::new(message_stream) + .keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); + + sse_stream.into_response() + }) + } + }); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + Self::new_with_func(method, path, response_fn, header_map) + } + + pub fn with_matcher(mut self, matcher: F) -> Self + where + F: Fn(&str, &HeaderMap) -> bool + Send + Sync + 'static, + { + self.matcher = Some(Arc::new(matcher)); + self + } + + pub fn add_header(mut self, key: HeaderName, val: HeaderValue) -> Self { + self.header_map.insert(key, val); + self + } + + pub fn without_matcher(mut self) -> Self { + self.matcher = None; + self + } + + pub fn expect(mut self, num_calls: usize) -> Self { + self.remaining_calls = Some(Arc::new(Mutex::new(num_calls))); + self + } + + pub fn unlimited_calls(mut self) -> Self { + self.remaining_calls = None; + self + } + + pub fn with_status(mut self, status: StatusCode) -> Self { + self.status = status; + self + } + + pub fn build(self) -> Mock { + Mock { + method: self.method, + path: self.path, + response: self.response, + header_map: self.header_map, + matcher: self.matcher, + remaining_calls: self.remaining_calls, + status: self.status, + response_func: self.response_func, + } + } + + // add_string with text/plain + pub fn new_text(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + + Self::new(method, path, response.into(), header_map) + } + + /** + MockBuilder::new_response( + Method::GET, + "/mcp".to_string(), + Box::new(|| { + // tokio::time::sleep(Duration::from_secs(1)).await; + let json_response = Json(json!({ + "status": "ok", + "data": [1, 2, 3] + })) + .into_response(); + Box::pin(async move { json_response }) + }), + ) + .build(), + */ + pub fn new_response(method: Method, path: String, response_func: AsyncResponseFn) -> Self { + Self::new_with_func(method, path, response_func, HeaderMap::new()) + } + + // new_json with application/json + pub fn new_json(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + Self::new(method, path, response.into(), header_map) + } + + // new_sse with text/event-stream + pub fn new_sse(method: Method, path: String, response: impl Into) -> Self { + let response = format!(r#"data: {}{}"#, response.into(), '\n'); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + // ensure message ends with a \n\n , if needed + let cr = if response.ends_with("\n\n") { + "" + } else { + "\n\n" + }; + Self::new(method, path, format!("{response}{cr}"), header_map) + } + + // new_raw with application/octet-stream + pub fn new_raw(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Self::new(method, path, response.into(), header_map) + } +} + +// MockServerHandle provides access to the request history after the server starts +pub struct MockServerHandle { + history: Arc>>, +} + +impl MockServerHandle { + pub async fn get_history(&self) -> Vec<(RequestRecord, ResponseRecord)> { + let history = self.history.lock().unwrap(); + history.iter().cloned().collect() + } + + pub async fn print(&self) { + let requests = self.get_history().await; + + let len = requests.len(); + println!("\n>>> {len} request(s) received <<<"); + + for (index, (request, response)) in requests.iter().enumerate() { + println!( + "\n--- Request {} of {len} ------------------------------------", + index + 1 + ); + println!("Method: {}", request.method); + println!("Path: {}", request.path); + // println!("Headers: {:#?}", request.headers); + println!("> headers "); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + println!("\n> Body"); + println!("{}\n", &request.body); + + println!(">>>>> Response <<<<<"); + println!("> status: {}", response.status); + println!("> headers"); + for (key, values) in &response.headers { + println!("{key}: {values:?}"); + } + println!("> Body"); + println!("{}", &response.body); + } + } +} + +// MockServer is the main struct for configuring and starting the mock server +pub struct SimpleMockServer { + mocks: Vec, + history: Arc>>, +} + +impl Default for SimpleMockServer { + fn default() -> Self { + Self::new() + } +} + +impl SimpleMockServer { + pub fn new() -> Self { + Self { + mocks: Vec::new(), + history: Arc::new(Mutex::new(VecDeque::new())), + } + } + + pub async fn start_with_mocks(mocks: Vec) -> (String, MockServerHandle) { + let mut server = SimpleMockServer::new(); + server.add_mocks(mocks); + server.start().await + } + + // Generic add function + pub fn add_mock_builder(&mut self, builder: MockBuilder) -> &mut Self { + self.mocks.push(builder.build()); + self + } + + pub fn add_mock(&mut self, mock: Mock) -> &mut Self { + self.mocks.push(mock); + self + } + + pub fn add_mocks(&mut self, mock: Vec) -> &mut Self { + mock.into_iter().for_each(|m| self.mocks.push(m)); + self + } + + pub async fn start(self) -> (String, MockServerHandle) { + let mocks = Arc::new(self.mocks); + let history = Arc::clone(&self.history); + + async fn handler( + mocks: Arc>, + history: Arc>>, + mut req: Request, + ) -> impl IntoResponse { + // Take ownership of the body using std::mem::take + let body = std::mem::take(req.body_mut()); + let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); + let body_str = String::from_utf8_lossy(&body_bytes).to_string(); + + let request_record = RequestRecord { + method: req.method().clone(), + path: req.uri().path().to_string(), + headers: req.headers().clone(), + body: body_str.clone(), + }; + + for m in mocks.iter() { + if m.method != *req.method() || m.path != req.uri().path() { + continue; + } + + if let Some(matcher) = &m.matcher { + if !(matcher)(&body_str, req.headers()) { + continue; + } + } + + if let Some(remaining) = &m.remaining_calls { + let mut rem = remaining.lock().unwrap(); + if *rem == 0 { + continue; + } + *rem -= 1; + } + + let mut resp = match m.response_func.as_ref() { + Some(get_response) => get_response().await.into_response(), + None => Response::new(Body::from(m.response.clone())), + }; + + // if let Some(resp_box) = &mut m.response_func.take() { + // let response = resp_box.into_response(); + // // *response.status_mut() = m.status; + // // m.response_func = Some(Box::new(response)); + // } + + // let mut resp = m.response_func.as_ref().unwrap().clone().to_owned(); + // let resp = *resp; + // *resp.into_response().status_mut() = m.status; + + // let mut response = m.response_func.as_ref().unwrap().clone(); + // let mut response = m.response_func.as_ref().unwrap().clone().to_owned(); + // let mut m = *response; + // *response.status_mut() = m.status; + // let resp = &*m.response_func.as_ref().unwrap().to_owned().clone().deref(); + + // let response = boxed_response.into_response(); + + // let mut resp = Response::new(Body::from(m.response.clone())); + *resp.status_mut() = m.status; + m.header_map.iter().for_each(|(k, v)| { + resp.headers_mut().insert(k, v.clone()); + }); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: m.response.clone(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + return resp; + } + + let resp = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: "".into(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + resp + } + + let app = Router::new().route( + "/{*path}", + any(move |req: Request| handler(Arc::clone(&mocks), Arc::clone(&history), req)), + ); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let url = format!("http://{local_addr}/"); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + ( + url, + MockServerHandle { + history: self.history, + }, + ) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs new file mode 100644 index 0000000..46a8525 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -0,0 +1,163 @@ +use async_trait::async_trait; +use rust_mcp_schema::{schema_utils::MessageFromServer, PingRequest, RpcError}; +use rust_mcp_sdk::{mcp_client::ClientHandler, McpClient}; +use serde_json::json; +use std::sync::Arc; +use tokio::sync::RwLock; + +#[cfg(feature = "hyper-server")] +pub mod test_client_common { + use rust_mcp_schema::{ + schema_utils::MessageFromServer, ClientCapabilities, Implementation, + InitializeRequestParams, LATEST_PROTOCOL_VERSION, + }; + use rust_mcp_sdk::{ + mcp_client::{client_runtime, ClientRuntime}, + McpClient, RequestOptions, SessionId, StreamableTransportOptions, + }; + use std::{collections::HashMap, sync::Arc, time::Duration}; + use tokio::sync::RwLock; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + use wiremock::{ + matchers::{body_json_string, method, path}, + Mock, MockServer, ResponseTemplate, + }; + + use crate::common::{ + create_sse_response, test_server_common::INITIALIZE_RESPONSE, wait_for_n_requests, + }; + + pub struct InitializedClient { + pub client: Arc, + pub mcp_url: String, + pub mock_server: MockServer, + } + + pub const TEST_SESSION_ID: &str = "test-session-id"; + pub const INITIALIZE_REQUEST: &str = r#"{"id":0,"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{},"clientInfo":{"name":"simple-rust-mcp-client-sse","title":"Simple Rust MCP Client (SSE)","version":"0.1.0"},"protocolVersion":"2025-06-18"}}"#; + + pub fn test_client_details() -> InitializeRequestParams { + InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + } + } + + pub async fn create_client( + mcp_url: &str, + custom_headers: Option>, + ) -> (Arc, Arc>>) { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let client_details: InitializeRequestParams = test_client_details(); + + let transport_options = StreamableTransportOptions { + mcp_url: mcp_url.to_string(), + request_options: RequestOptions { + request_timeout: Duration::from_secs(2), + custom_headers, + ..RequestOptions::default() + }, + }; + + let message_history = Arc::new(RwLock::new(vec![])); + let handler = super::TestClientHandler { + message_history: message_history.clone(), + }; + + let client = + client_runtime::with_transport_options(client_details, transport_options, handler); + + // client.clone().start().await.unwrap(); + (client, message_history) + } + + pub async fn initialize_client( + session_id: Option, + custom_headers: Option>, + ) -> InitializedClient { + let mock_server = MockServer::start().await; + + // initialize response + let mut response = create_sse_response(INITIALIZE_RESPONSE); + + if let Some(session_id) = session_id { + response = response.append_header("mcp-session-id", session_id.as_str()); + } + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, custom_headers).await; + + client.clone().start().await.unwrap(); + + wait_for_n_requests(&mock_server, 2, None).await; + + InitializedClient { + client, + mcp_url, + mock_server, + } + } +} + +// Custom responder for SSE with 10 ping messages +struct SsePingResponder; + +// Test handler +pub struct TestClientHandler { + message_history: Arc>>, +} + +impl TestClientHandler { + async fn register_message(&self, message: &MessageFromServer) { + let mut lock = self.message_history.write().await; + lock.push(message.clone()); + } +} + +#[async_trait] +impl ClientHandler for TestClientHandler { + async fn handle_ping_request( + &self, + request: PingRequest, + runtime: &dyn McpClient, + ) -> std::result::Result { + self.register_message(&request.into()).await; + + Ok(rust_mcp_schema::Result { + meta: Some(json!({"meta_number":1515}).as_object().unwrap().to_owned()), + extra: None, + }) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index 176e0d2..769f8c6 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,30 +1,30 @@ #[cfg(feature = "hyper-server")] pub mod test_server_common { + use crate::common::sample_tools::SayHelloTool; use async_trait::async_trait; use rust_mcp_schema::schema_utils::CallToolError; use rust_mcp_schema::{ CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, RpcError, }; + use rust_mcp_sdk::id_generator::IdGenerator; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; - use tokio_stream::StreamExt; - use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequest, InitializeRequestParams, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, }; use rust_mcp_sdk::{ - mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, + mcp_server::{hyper_server, HyperServer, HyperServerOptions, ServerHandler}, McpServer, SessionId, }; use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::time::timeout; - - use crate::common::sample_tools::SayHelloTool; + use tokio_stream::StreamExt; pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + pub const INITIALIZE_RESPONSE: &str = r#"{"result":{"protocolVersion":"2025-06-18","capabilities":{"prompts":{},"resources":{"subscribe":true},"tools":{},"logging":{}},"serverInfo":{"name":"example-servers/everything","version":"1.0.0"}},"jsonrpc":"2.0","id":0}"#; pub struct LaunchedServer { pub hyper_runtime: HyperRuntime, @@ -150,14 +150,17 @@ pub mod test_server_common { } } - impl IdGenerator for TestIdGenerator { - fn generate(&self) -> SessionId { + impl IdGenerator for TestIdGenerator + where + T: From, + { + fn generate(&self) -> T { let mut lock = self.generated.write().unwrap(); *lock += 1; if *lock > self.constant_ids.len() { *lock = 1; } - self.constant_ids[*lock - 1].to_owned() + T::from(self.constant_ids[*lock - 1].to_owned()) } } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs new file mode 100644 index 0000000..cb82ff5 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -0,0 +1,823 @@ +#[path = "common/common.rs"] +pub mod common; + +use common::test_client_common::create_client; +use hyper::{Method, StatusCode}; +use rust_mcp_schema::{ + schema_utils::{ + ClientJsonrpcRequest, ClientMessage, MessageFromServer, RequestFromClient, + RequestFromServer, ResultFromServer, RpcMessage, ServerMessage, + }, + RequestId, ServerRequest, ServerResult, +}; +use rust_mcp_sdk::{ + error::McpSdkError, mcp_server::HyperServerOptions, McpClient, TransportError, + MCP_LAST_EVENT_ID_HEADER, +}; +use serde_json::{json, Value}; +use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; +use wiremock::{ + http::{HeaderName, HeaderValue}, + matchers::{body_json_string, header, method, path}, + Mock, MockServer, ResponseTemplate, +}; + +use crate::common::{ + create_sse_response, debug_wiremock, random_port, + test_client_common::{ + initialize_client, InitializedClient, INITIALIZE_REQUEST, TEST_SESSION_ID, + }, + test_server_common::{ + create_start_server, LaunchedServer, TestIdGenerator, INITIALIZE_RESPONSE, + }, + wait_for_n_requests, wiremock_request, MockBuilder, SimpleMockServer, SseEvent, +}; + +// should send JSON-RPC messages via POST +#[tokio::test] +async fn should_send_json_rpc_messages_via_post() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should send batch messages +#[tokio::test] +async fn should_send_batch_messages() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(result + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + // debug_wiremock(&mock_server).await; +} + +// should store session ID received during initialization +#[tokio::test] +async fn should_store_session_id_received_during_initialization() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = + create_sse_response(INITIALIZE_RESPONSE).append_header("mcp-session-id", "test-session-id"); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should terminate session with DELETE request +#[tokio::test] +async fn should_terminate_session_with_delete_request() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 405 response when server doesn't support session termination +#[tokio::test] +async fn should_handle_405_unsupported_session_termination() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(405)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 404 response when session expires +#[tokio::test] +async fn should_handle_404_response_when_session_expires() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(404)) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + matches!( + result, + Err(McpSdkError::Transport(TransportError::SessionExpired)) + ); +} + +// should handle non-streaming JSON response +#[tokio::test] +async fn should_handle_non_streaming_json_response() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + let response = ResponseTemplate::new(200) + .set_body_json(json!({ + "id":1,"jsonrpc":"2.0", "result":{"something":"good"} + })) + .insert_header("Content-Type", "application/json"); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + let request = RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})); + + let result = client.request(request, None).await.unwrap(); + + let ResultFromServer::ServerResult(ServerResult::Result(result)) = result else { + panic!("Wrong result variant!") + }; + + let extra = result.extra.unwrap(); + assert_eq!(extra.get("something").unwrap(), "good"); +} + +// should handle successful initial GET connection for SSE +#[tokio::test] +async fn should_handle_successful_initial_get_connection_for_sse() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str(&"data: Connection established\n\n".to_string()); + + let response = ResponseTemplate::new(200) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()) +} + +#[tokio::test] +async fn should_receive_server_initiated_messaged() { + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + enable_json_response: Some(false), + ..Default::default() + }; + let LaunchedServer { + hyper_runtime, + streamable_url, + sse_url, + sse_message_url, + } = create_start_server(server_options).await; + + let (client, message_history) = create_client(&streamable_url, None).await; + + client.clone().start().await.unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let result = hyper_runtime + .ping(&"AAA-BBB-CCC".to_string(), None) + .await + .unwrap(); + + let lock = message_history.read().await; + let ping_request = lock + .iter() + .find(|m| { + matches!( + m, + MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_) + )) + ) + }) + .unwrap(); + let MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_), + )) = ping_request + else { + panic!("Request is not a match!") + }; + assert!(result.meta.is_some()); + + let v = result.meta.unwrap().get("meta_number").unwrap().clone(); + + assert!(matches!(v, Value::Number(value) if value.as_i64().unwrap()==1515)) //1515 is passed from TestClientHandler +} + +// should attempt initial GET connection and handle 405 gracefully +#[tokio::test] +async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(405)) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str(&"data: Connection established\n\n".to_string()); + + let response = ResponseTemplate::new(405) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()); + + // send a batch message, runtime should work as expected with no issue + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); +} + +// should handle multiple concurrent SSE streams +#[tokio::test] +async fn should_handle_multiple_concurrent_sse_streams() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(|req: &wiremock::Request| { + let body_string = String::from_utf8(req.body.clone()).unwrap(); + if body_string.contains("test3") { + create_sse_response(r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#) + } else { + create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ) + } + }) + .expect(2) + .mount(&mock_server) + .await; + + let message_3 = RequestFromClient::CustomRequest(json!({"method": "test3", "params": {}})); + let request1 = client.send_batch(vec![message_1, message_2], None); + let request2 = client.send(message_3.into(), None, None); + + // Run them concurrently and wait for both + let (res_batch, res_single) = tokio::join!(request1, request2); + + let res_batch = res_batch.unwrap().unwrap(); + // two results for two requests in the batch + assert_eq!(res_batch.len(), 2); + assert!(res_batch.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(res_batch + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + let res_single = res_single.unwrap().unwrap(); + let ServerMessage::Response(res_single) = res_single else { + panic!("invalid respinse type, expected Result!") + }; + + assert!(matches!(res_single.id, RequestId::Integer(id) if id==1)); +} + +// should throw error when invalid content-type is received +#[tokio::test] +async fn should_throw_error_when_invalid_content_type_is_received() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":0,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "text/plain", + )) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + let Err(McpSdkError::Transport(TransportError::UnexpectedContentType(content_type))) = result + else { + panic!("Expected a TransportError::UnexpectedContentType error!"); + }; + + assert_eq!(content_type, "text/plain"); +} + +// should always send specified custom headers +#[tokio::test] +async fn should_always_send_specified_custom_headers() { + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, Some(headers)).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "application/json", + )) + .expect(1) + .mount(&mock_server) + .await; + + let _result = client.ping(None).await; + + let requests = mock_server.received_requests().await.unwrap(); + + assert_eq!(requests.len(), 4); + assert!(requests + .iter() + .all(|r| r.headers.get(&"X-Custom-Header".into()).unwrap().as_str() == "CustomValue")); + + debug_wiremock(&mock_server).await +} + +// should reconnect a GET-initiated notification stream that fails + +#[tokio::test] +async fn should_reconnect_a_get_initiated_notification_stream_that_fails() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // two GET Mock, each expects one call , first time it fails, second retry it succeeds + let response = ResponseTemplate::new(502) + .set_body_raw("".to_string().into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .up_to_n_times(1) + .mount(&mock_server) + .await; + + let response = ResponseTemplate::new(200) + .set_body_raw( + "data: Connection established\n\n".to_string().into_bytes(), + "text/event-stream", + ) + .append_header("Connection", "keep-alive"); + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); +} + +//****************** Resumability ****************** +// should pass lastEventId when reconnecting +#[tokio::test] +async fn should_pass_last_event_id_when_reconnecting() { + let msg = r#"{"jsonrpc":"2.0","method":"notifications/message","params":{"data":{},"level":"debug"}}"#; + + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE).build(), + MockBuilder::new_breakable_sse( + Method::GET, + "/mcp".to_string(), + SseEvent { + data: Some(msg.into()), + event: Some("message".to_string()), + id: None, + }, + Duration::from_millis(100), + 5, + ) + .expect(2) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + // give it time for re-connection + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + + let get_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::GET) + .collect(); + + // there should be more than one GET reueat, indicating reconnection + assert!(get_requests.len() > 1); + + let Some(last_get_request) = get_requests.last() else { + panic!("Unable to find last GET request!"); + }; + + let last_event_id = last_get_request + .0 + .headers + .get(axum::http::HeaderName::from_static( + MCP_LAST_EVENT_ID_HEADER, + )); + + // last-event-id should be sent + assert!( + matches!(last_event_id, Some(last_event_id) if last_event_id.to_str().unwrap().starts_with("msg-id")) + ); + + // custom headers should be passed for all GET requests + assert!(get_requests.iter().all(|r| r + .0 + .headers + .get(axum::http::HeaderName::from_str("X-Custom-Header").unwrap()) + .unwrap() + .to_str() + .unwrap() + == "CustomValue")); + + println!("last_event_id {:?} ", last_event_id.unwrap()); +} + +// should NOT reconnect a POST-initiated stream that fails +#[tokio::test] +async fn should_not_reconnect_a_post_initiated_stream_that_fails() { + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE) + .expect(1) + .build(), + MockBuilder::new_sse(Method::GET, "/mcp".to_string(), "".to_string()) + .with_status(StatusCode::METHOD_NOT_ALLOWED) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .expect(1) + .build(), + MockBuilder::new_breakable_sse( + Method::POST, + "/mcp".to_string(), + SseEvent { + data: Some("msg".to_string()), + event: None, + id: None, + }, + Duration::ZERO, + 0, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + let result = client.send_roots_list_changed(None).await; + + assert!(result.is_err()); + + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + let post_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::POST) + .collect(); + assert_eq!(post_requests.len(), 3); // initialize, initialized, root_list_changed +} + +//****************** Auth ****************** +// attempts auth flow on 401 during POST request +// invalidates all credentials on InvalidClientError during auth +// invalidates all credentials on UnauthorizedClientError during auth +//invalidates tokens on InvalidGrantError during auth + +//****************** Others ****************** +// custom fetch in auth code paths +// should support custom reconnection options +// uses custom fetch implementation if provided +// should have exponential backoff with configurable maxRetries diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs similarity index 99% rename from crates/rust-mcp-sdk/tests/test_streamable_http.rs rename to crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 23ca27f..4809d6d 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -8,13 +8,12 @@ use rust_mcp_schema::{ SdkErrorCodes, ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse, ServerMessages, }, - CallToolRequest, CallToolRequestParams, ListPromptsRequestParams, ListRootsRequestParams, - ListRootsResult, ListToolsRequest, LoggingLevel, LoggingMessageNotificationParams, RequestId, - RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, + CallToolRequest, CallToolRequestParams, ListRootsResult, ListToolsRequest, LoggingLevel, + LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, + ServerRequest, ServerResult, }; use rust_mcp_sdk::mcp_server::HyperServerOptions; use serde_json::{json, Map, Value}; -use tokio_stream::StreamExt; use crate::common::{ random_port, read_sse_event, read_sse_event_from_stream, send_delete_request, send_get_request, diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index ec061bb..2f03580 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -42,10 +42,12 @@ workspace = true ### FEATURES ################################################################# [features] -default = ["stdio", "sse", "2025_06_18"] # Default features +default = ["stdio", "sse", "streamable-http", "2025_06_18"] # Default features stdio = [] sse = ["reqwest"] +streamable-http = ["reqwest"] + # enabled mcp protocol version 2025_06_18 2025_06_18 = ["rust-mcp-schema/2025_06_18", "rust-mcp-schema/schema_utils"] diff --git a/crates/rust-mcp-transport/README.md b/crates/rust-mcp-transport/README.md index 23b78bf..30cad83 100644 --- a/crates/rust-mcp-transport/README.md +++ b/crates/rust-mcp-transport/README.md @@ -14,7 +14,7 @@ let transport = StdioTransport::new(TransportOptions { timeout: 60_000 })?; ``` -Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) example for a complete demonstration. +Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) example for a complete demonstration. ### For MCP Client @@ -51,7 +51,7 @@ let transport = StdioTransport::create_with_server_launch( )?; ``` -Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) example for a complete demonstration. +Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) example for a complete demonstration. --- diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index f201aa0..8d55bd0 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -5,7 +5,7 @@ use crate::transport::Transport; use crate::utils::{ extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel, }; -use crate::{IoStream, McpDispatch, TransportOptions}; +use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions}; use async_trait::async_trait; use bytes::Bytes; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; @@ -13,8 +13,13 @@ use reqwest::Client; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; -use crate::schema::schema_utils::McpMessage; -use crate::schema::RequestId; +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; use std::cmp::Ordering; use std::collections::HashMap; use std::pin::Pin; @@ -25,7 +30,7 @@ use tokio::sync::{mpsc, oneshot, Mutex}; const DEFAULT_CHANNEL_CAPACITY: usize = 64; const DEFAULT_MAX_RETRY: usize = 5; -const DEFAULT_RETRY_TIME_SECONDS: u64 = 3; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; /// Configuration options for the Client SSE Transport @@ -102,10 +107,9 @@ where let base_url = match extract_origin(server_url) { Some(url) => url, None => { - let error_message = - format!("Failed to extract origin from server URL: {server_url}"); - tracing::error!(error_message); - return Err(TransportError::InvalidOptions(error_message)); + let message = format!("Failed to extract origin from server URL: {server_url}"); + tracing::error!(message); + return Err(TransportError::Configuration { message }); } }; @@ -145,12 +149,15 @@ where let mut header_map = HeaderMap::new(); for (key, value) in headers { - let header_name = key - .parse::() - .map_err(|e| TransportError::InvalidOptions(format!("Invalid header name: {e}")))?; - let header_value = HeaderValue::from_str(value).map_err(|e| { - TransportError::InvalidOptions(format!("Invalid header value: {e}")) - })?; + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; header_map.insert(header_name, header_value); } @@ -172,10 +179,12 @@ where } if let Some(endpoint_origin) = extract_origin(&endpoint) { if endpoint_origin.cmp(&self.base_url) != Ordering::Equal { - return Err(TransportError::InvalidOptions(format!( + return Err(TransportError::Configuration { + message: format!( "Endpoint origin does not match connection origin. expected: {} , received: {}", self.base_url, endpoint_origin - ))); + ), + }); } return Ok(endpoint); } @@ -284,8 +293,8 @@ where Some(data) => { // trim the trailing \n before making a request let body = String::from_utf8_lossy(&data).trim().to_string(); - if let Err(e) = http_post(&client_clone, &post_url, body, &custom_headers).await { - tracing::error!("Failed to POST message: {e:?}"); + if let Err(e) = http_post(&client_clone, &post_url, body,None, custom_headers.as_ref()).await { + tracing::error!("Failed to POST message: {e}"); } }, None => break, // Exit if channel is closed @@ -335,7 +344,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function for ClientSseTransport" .to_string(), )) @@ -346,7 +355,7 @@ where _: Duration, _: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for ClientSseTransport".to_string(), )) } @@ -413,3 +422,55 @@ where pending_requests.remove(request_id) } } + +#[async_trait] +impl McpDispatch + for ClientSseTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientSseTransport +{ +} diff --git a/crates/rust-mcp-transport/src/client_streamable_http.rs b/crates/rust-mcp-transport/src/client_streamable_http.rs new file mode 100644 index 0000000..c318649 --- /dev/null +++ b/crates/rust-mcp-transport/src/client_streamable_http.rs @@ -0,0 +1,515 @@ +use crate::error::TransportError; +use crate::mcp_stream::MCPStream; + +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; +use crate::utils::{ + http_delete, http_post, CancellationTokenSource, ReadableChannel, StreamableHttpStream, + WritableChannel, +}; +use crate::{error::TransportResult, IoStream, McpDispatch, MessageDispatcher, Transport}; +use crate::{SessionId, TransportDispatcher, TransportOptions}; +use async_trait::async_trait; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use reqwest::Client; +use std::collections::HashMap; +use std::pin::Pin; +use std::{sync::Arc, time::Duration}; +use tokio::io::{BufReader, BufWriter}; +use tokio::sync::oneshot::Sender; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; + +const DEFAULT_CHANNEL_CAPACITY: usize = 64; +const DEFAULT_MAX_RETRY: usize = 5; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; +const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; + +pub struct StreamableTransportOptions { + pub mcp_url: String, + pub request_options: RequestOptions, +} + +impl StreamableTransportOptions { + pub async fn terminate_session(&self, session_id: Option<&SessionId>) { + let client = Client::new(); + match http_delete(&client, &self.mcp_url, session_id, None).await { + Ok(_) => {} + Err(TransportError::Http(status_code)) => { + tracing::info!("Session termination failed with status code {status_code}",); + } + Err(error) => { + tracing::info!("Session termination failed with error :{error}"); + } + }; + } +} + +pub struct RequestOptions { + pub request_timeout: Duration, + pub retry_delay: Option, + pub max_retries: Option, + pub custom_headers: Option>, +} + +impl Default for RequestOptions { + fn default() -> Self { + Self { + request_timeout: TransportOptions::default().timeout, + retry_delay: None, + max_retries: None, + custom_headers: None, + } + } +} + +pub struct ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + /// Optional cancellation token source for shutting down the transport + shutdown_source: tokio::sync::RwLock>, + /// Flag indicating if the transport is shut down + is_shut_down: Mutex, + /// Timeout duration for MCP messages + request_timeout: Duration, + /// HTTP client for making requests + client: Client, + /// URL for the SSE endpoint + mcp_server_url: String, + /// Delay between retry attempts + retry_delay: Duration, + /// Maximum number of retry attempts + max_retries: usize, + /// Optional custom HTTP headers + custom_headers: Option, + sse_task: tokio::sync::RwLock>>, + post_task: tokio::sync::RwLock>>, + message_sender: Arc>>>, + error_stream: tokio::sync::RwLock>, + pending_requests: Arc>>>, + session_id: Arc>>, + standalone: bool, +} + +impl ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + pub fn new( + options: &StreamableTransportOptions, + session_id: Option, + standalone: bool, + ) -> TransportResult { + let client = Client::new(); + + let headers = match &options.request_options.custom_headers { + Some(h) => Some(Self::validate_headers(h)?), + None => None, + }; + + let mcp_server_url = options.mcp_url.to_owned(); + Ok(Self { + shutdown_source: tokio::sync::RwLock::new(None), + is_shut_down: Mutex::new(false), + request_timeout: options.request_options.request_timeout, + client, + mcp_server_url, + retry_delay: options + .request_options + .retry_delay + .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)), + max_retries: options + .request_options + .max_retries + .unwrap_or(DEFAULT_MAX_RETRY), + sse_task: tokio::sync::RwLock::new(None), + post_task: tokio::sync::RwLock::new(None), + custom_headers: headers, + message_sender: Arc::new(tokio::sync::RwLock::new(None)), + error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: Arc::new(tokio::sync::RwLock::new(session_id)), + standalone, + }) + } + + fn validate_headers(headers: &HashMap) -> TransportResult { + let mut header_map = HeaderMap::new(); + for (key, value) in headers { + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; + header_map.insert(header_name, header_value); + } + Ok(header_map) + } + + pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { + let mut lock = self.message_sender.write().await; + *lock = Some(sender); + } + + pub(crate) async fn set_error_stream( + &self, + error_stream: Pin>, + ) { + let mut lock = self.error_stream.write().await; + *lock = Some(IoStream::Readable(error_stream)); + } +} + +#[async_trait] +impl Transport for ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + async fn start(&self) -> TransportResult> + where + MessageDispatcher: McpDispatch, + { + if self.standalone { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + let session_id = self.session_id.read().await.to_owned(); + + let sse_response = streamable_http + .make_standalone_stream_connection(&cancellation_token_sse, &custom_headers, None) + .await?; + + let sse_task_handle = tokio::spawn(async move { + if let Err(error) = streamable_http + .run_standalone(&cancellation_token_sse, &custom_headers, sse_response) + .await + { + if !matches!(error, TransportError::Cancelled(_)) { + tracing::warn!("{error}"); + } + } + }); + + let mut sse_task_lock = self.sse_task.write().await; + *sse_task_lock = Some(sse_task_handle); + + let post_url = self.mcp_server_url.clone(); + let client = self.client.clone(); + let custom_headers = self.custom_headers.clone(); + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some(data) => { + // trim the trailing \n before making a request + let payload = String::from_utf8_lossy(&data).trim().to_string(); + + if let Err(e) = http_post( + &client, + &post_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await{ + tracing::error!("Failed to POST message: {e}") + } + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create writable stream + let writable: Mutex>> = + Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx }))); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create( + readable, + writable, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + Ok(stream) + } else { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + // let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (write_tx, mut write_rx): ( + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + tokio::sync::mpsc::Receiver<( + String, + tokio::sync::oneshot::Sender>, + )>, + ) = tokio::sync::mpsc::channel(DEFAULT_CHANNEL_CAPACITY); // Buffer size as needed + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some((data, ack_tx)) => { + // trim the trailing \n before making a request + let payload = data.trim().to_string(); + let result = streamable_http.run(payload, &cancellation_token_sse, &custom_headers).await; + let _ = ack_tx.send(result);// Ignore error if receiver dropped + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create_with_ack( + readable, + write_tx, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + + Ok(stream) + } + } + + fn message_sender(&self) -> Arc>>> { + self.message_sender.clone() as _ + } + + fn error_stream(&self) -> &tokio::sync::RwLock> { + &self.error_stream as _ + } + async fn shut_down(&self) -> TransportResult<()> { + // Trigger cancellation + let mut cancellation_lock = self.shutdown_source.write().await; + if let Some(source) = cancellation_lock.as_ref() { + source.cancel()?; + } + *cancellation_lock = None; // Clear cancellation_source + + // Mark as shut down + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + // Get task handle + let post_task = self.post_task.write().await.take(); + + // // Wait for tasks to complete with a timeout + let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS); + let shutdown_future = async { + if let Some(post_handle) = post_task { + let _ = post_handle.await; + } + Ok::<(), TransportError>(()) + }; + + tokio::select! { + result = shutdown_future => { + result // result of task completion + } + _ = tokio::time::sleep(timeout) => { + tracing::warn!("Shutdown timed out after {:?}", timeout); + Err(TransportError::ShutdownTimeout) + } + } + } + async fn is_shut_down(&self) -> bool { + let result = self.is_shut_down.lock().await; + *result + } + async fn consume_string_payload(&self, _: &str) -> TransportResult<()> { + Err(TransportError::Internal( + "Invalid invocation of consume_string_payload() function for ClientStreamableTransport" + .to_string(), + )) + } + + async fn pending_request_tx(&self, request_id: &RequestId) -> Option> { + let mut pending_requests = self.pending_requests.lock().await; + pending_requests.remove(request_id) + } + + async fn keep_alive( + &self, + _: Duration, + _: oneshot::Sender<()>, + ) -> TransportResult> { + Err(TransportError::Internal( + "Invalid invocation of keep_alive() function for ClientStreamableTransport".to_string(), + )) + } + + async fn session_id(&self) -> Option { + let guard = self.session_id.read().await; + guard.clone() + } +} + +#[async_trait] +impl McpDispatch + for ClientStreamableTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientStreamableTransport +{ +} diff --git a/crates/rust-mcp-transport/src/constants.rs b/crates/rust-mcp-transport/src/constants.rs new file mode 100644 index 0000000..6ae0342 --- /dev/null +++ b/crates/rust-mcp-transport/src/constants.rs @@ -0,0 +1,3 @@ +pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; +pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; +pub const MCP_LAST_EVENT_ID_HEADER: &str = "last-event-id"; diff --git a/crates/rust-mcp-transport/src/error.rs b/crates/rust-mcp-transport/src/error.rs index 8f8b62f..a244456 100644 --- a/crates/rust-mcp-transport/src/error.rs +++ b/crates/rust-mcp-transport/src/error.rs @@ -1,11 +1,14 @@ use crate::schema::{schema_utils::SdkError, RpcError}; -use thiserror::Error; - use crate::utils::CancellationError; use core::fmt; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::Error as ReqwestError; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::StatusCode; use std::any::Any; +use std::io::Error as IoError; +use thiserror::Error; use tokio::sync::{broadcast, mpsc}; - /// A wrapper around a broadcast send error. This structure allows for generic error handling /// by boxing the underlying error into a type-erased form. #[derive(Debug)] @@ -80,31 +83,53 @@ pub type TransportResult = core::result::Result; #[derive(Debug, Error)] pub enum TransportError { - #[error("{0}")] - InvalidOptions(String), + #[error("Session expired or not found")] + SessionExpired, + + #[error("Failed to open SSE stream: {0}")] + FailedToOpenSSEStream(String), + + #[error("Unexpected content type: '{0}'")] + UnexpectedContentType(String), + + #[error("Failed to send message: {0}")] + SendFailure(String), + + #[error("I/O error: {0}")] + Io(#[from] IoError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP connection error: {0}")] + HttpConnection(#[from] ReqwestError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP error: {0}")] + Http(StatusCode), + + #[error("SDK error: {0}")] + Sdk(#[from] SdkError), + + #[error("Operation cancelled: {0}")] + Cancelled(#[from] CancellationError), + + #[error("Channel closed: {0}")] + ChannelClosed(#[from] tokio::sync::oneshot::error::RecvError), + + #[error("Configuration error: {message}")] + Configuration { message: String }, + #[error("{0}")] SendError(#[from] GenericSendError), - #[error("{0}")] - WatchSendError(#[from] GenericWatchSendError), - #[error("Send Error: {0}")] - StdioError(#[from] std::io::Error), + #[error("{0}")] JsonrpcError(#[from] RpcError), - #[error("{0}")] - SdkError(#[from] SdkError), - #[error("Process error{0}")] + + #[error("Process error: {0}")] ProcessError(String), - #[error("{0}")] - FromString(String), - #[error("{0}")] - OneshotRecvError(#[from] tokio::sync::oneshot::error::RecvError), - #[cfg(feature = "sse")] - #[error("{0}")] - SendMessageError(#[from] reqwest::Error), - #[error("Http Error: {0}")] - HttpError(u16), + + #[error("Internal error: {0}")] + Internal(String), + #[error("Shutdown timed out")] ShutdownTimeout, - #[error("Cancellation error : {0}")] - CancellationError(#[from] CancellationError), } diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 1634922..4a918db 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -1,25 +1,38 @@ // Copyright (c) 2025 mcp-rust-stack // Licensed under the MIT License. See LICENSE file for details. // Modifications to this file must be documented with a description of the changes made. + #[cfg(feature = "sse")] mod client_sse; +#[cfg(feature = "streamable-http")] +mod client_streamable_http; +mod constants; pub mod error; mod mcp_stream; mod message_dispatcher; mod schema; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod sse; +#[cfg(feature = "stdio")] mod stdio; mod transport; mod utils; #[cfg(feature = "sse")] pub use client_sse::*; +#[cfg(feature = "streamable-http")] +pub use client_streamable_http::*; +pub use constants::*; pub use message_dispatcher::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub use sse::*; +#[cfg(feature = "stdio")] pub use stdio::*; pub use transport::*; // Type alias for session identifier, represented as a String pub type SessionId = String; +// Type alias for stream identifier (that will be used at the transport scope), represented as a String +pub type StreamId = String; +// Type alias for event (MCP message) identifier, represented as a String +pub type EventId = String; diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 08bdc21..0b10918 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -57,6 +57,43 @@ impl MCPStream { (stream, sender, error_io) } + pub fn create_with_ack( + readable: Pin>, + writable: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + error_io: IoStream, + pending_requests: Arc>>>, + request_timeout: Duration, + cancellation_token: CancellationToken, + ) -> ( + tokio_stream::wrappers::ReceiverStream, + MessageDispatcher, + IoStream, + ) + where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + X: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + { + let (tx, rx) = tokio::sync::mpsc::channel::(CHANNEL_CAPACITY); + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + // Clone cancellation_token for reader + let reader_token = cancellation_token.clone(); + + #[allow(clippy::let_underscore_future)] + let _ = Self::spawn_reader(readable, tx, reader_token); + + let sender = MessageDispatcher::new_with_acknowledgement( + pending_requests, + writable, + request_timeout, + ); + + (stream, sender, error_io) + } + /// Creates a new task that continuously reads from the readable stream. /// The received data is deserialized into a JsonrpcMessage. If the deserialization is successful, /// the object is transmitted. If the object is a response or error corresponding to a pending request, diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index ea1eb04..7c7c93e 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -29,7 +29,13 @@ use crate::McpDispatch; /// a configurable timeout mechanism for asynchronous responses. pub struct MessageDispatcher { pending_requests: Arc>>>, - writable_std: Mutex>>, + writable_std: Option>>>, + writable_tx: Option< + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + >, request_timeout: Duration, } @@ -51,7 +57,24 @@ impl MessageDispatcher { ) -> Self { Self { pending_requests, - writable_std, + writable_std: Some(writable_std), + writable_tx: None, + request_timeout, + } + } + + pub fn new_with_acknowledgement( + pending_requests: Arc>>>, + writable_tx: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + request_timeout: Duration, + ) -> Self { + Self { + pending_requests, + writable_tx: Some(writable_tx), + writable_std: None, request_timeout, } } @@ -125,7 +148,7 @@ impl McpDispatch match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { Ok(response) => Ok(Some(ServerMessages::Single(response))), Err(error) => match error { - TransportError::OneshotRecvError(_) => { + TransportError::ChannelClosed(_) => { Err(schema_utils::SdkError::connection_closed().into()) } _ => Err(error), @@ -147,6 +170,9 @@ impl McpDispatch }) .unzip(); + // Ensure all request IDs are stored before sending the request + let tasks = join_all(pending_tasks).await; + // send the batch messages to the server let message_payload = serde_json::to_string(&client_messages).map_err(|_| { crate::error::TransportError::JsonrpcError(RpcError::parse_error()) @@ -154,12 +180,10 @@ impl McpDispatch self.write_str(message_payload.as_str()).await?; // no request in the batch, no need to wait for the result - if pending_tasks.is_empty() { + if request_ids.is_empty() { return Ok(None); } - let tasks = join_all(pending_tasks).await; - let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| { rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout))) }); @@ -210,11 +234,24 @@ impl McpDispatch /// appending a newline character and flushing the stream afterward. /// async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(format!("{err}")))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } @@ -339,10 +376,23 @@ impl McpDispatch /// appending a newline character and flushing the stream afterward. /// async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(err.to_string()))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 50dbb32..09809e4 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -156,7 +156,7 @@ impl Transport {} - Err(TransportError::StdioError(error)) => { + Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { let _ = disconnect_tx.send(()); break; diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 0b67d64..11bd0a6 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -1,5 +1,6 @@ use crate::schema::schema_utils::{ - ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, + ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage, + ServerMessages, }; use crate::schema::RequestId; use async_trait::async_trait; @@ -193,22 +194,22 @@ where #[cfg(unix)] command.process_group(0); - let mut process = command.spawn().map_err(TransportError::StdioError)?; + let mut process = command.spawn().map_err(TransportError::Io)?; let stdin = process .stdin .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?; let stdout = process .stdout .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?; let stderr = process .stderr .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?; let pending_requests_clone = self.pending_requests.clone(); @@ -274,7 +275,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(), )) } @@ -284,7 +285,7 @@ where _interval: Duration, _disconnect_tx: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for StdioTransport".to_string(), )) } @@ -364,3 +365,55 @@ impl > for StdioTransport { } + +#[async_trait] +impl McpDispatch + for StdioTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for StdioTransport +{ +} diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index 3d17ebd..b8e3ddc 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -1,15 +1,12 @@ -use std::{pin::Pin, sync::Arc, time::Duration}; - -use crate::schema::RequestId; +use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; +use crate::{schema::RequestId, SessionId}; use async_trait::async_trait; - +use std::{pin::Pin, sync::Arc, time::Duration}; use tokio::{ sync::oneshot::{self, Sender}, task::JoinHandle, }; -use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; - /// Default Timeout in milliseconds const DEFAULT_TIMEOUT_MSEC: u64 = 60_000; @@ -125,6 +122,9 @@ where interval: Duration, disconnect_tx: oneshot::Sender<()>, ) -> TransportResult>; + async fn session_id(&self) -> Option { + None + } } /// A composite trait that combines both transport and dispatch capabilities for the MCP protocol. @@ -160,3 +160,26 @@ where OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { } + +// pub trait IntoClientTransport { +// type TransportType: Transport< +// ServerMessages, +// MessageFromClient, +// ServerMessage, +// ClientMessages, +// ClientMessage, +// >; + +// fn into_transport(self, session_id: Option) -> TransportResult; +// } + +// impl IntoClientTransport for T +// where +// T: Transport, +// { +// type TransportType = T; + +// fn into_transport(self, _: Option) -> TransportResult { +// Ok(self) +// } +// } diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 218d517..82d7326 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -1,21 +1,29 @@ mod cancellation_token; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod http_utils; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod readable_channel; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +mod sse_parser; #[cfg(feature = "sse")] mod sse_stream; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +mod streamable_http_stream; +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; pub(crate) use cancellation_token::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use readable_channel::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +pub(crate) use sse_parser::*; #[cfg(feature = "sse")] pub(crate) use sse_stream::*; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +pub(crate) use streamable_http_stream::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use writable_channel::*; use crate::schema::schema_utils::SdkError; @@ -23,16 +31,16 @@ use tokio::time::{timeout, Duration}; use crate::error::{TransportError, TransportResult}; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] use crate::SessionId; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where F: std::future::Future>, // The operation returns a Result - E: Into, // The error type must be convertible to TransportError + E: Into, { match timeout(timeout_duration, operation).await { - Ok(result) => result.map_err(|err| err.into()), // Convert the error type into TransportError + Ok(result) => result.map_err(|err| err.into()), Err(_) => Err(SdkError::request_timeout(timeout_duration.as_millis()).into()), // Timeout error } } @@ -46,7 +54,7 @@ where /// # Returns /// A String containing the endpoint with the session ID added as a query parameter /// -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { // Handle empty endpoint let base = if endpoint.is_empty() { "/" } else { endpoint }; diff --git a/crates/rust-mcp-transport/src/utils/http_utils.rs b/crates/rust-mcp-transport/src/utils/http_utils.rs index 701dcb0..84b62dd 100644 --- a/crates/rust-mcp-transport/src/utils/http_utils.rs +++ b/crates/rust-mcp-transport/src/utils/http_utils.rs @@ -1,7 +1,35 @@ use crate::error::{TransportError, TransportResult}; +use crate::{SessionId, MCP_SESSION_ID_HEADER}; -use reqwest::header::{HeaderMap, CONTENT_TYPE}; -use reqwest::Client; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE}; +use reqwest::{Client, Response}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponseType { + EventStream, + Json, +} + +/// Determines the response type based on the `Content-Type` header. +pub async fn validate_response_type(response: &Response) -> TransportResult { + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(content_type) => { + let content_type_str = content_type.to_str().map_err(|_| { + TransportError::UnexpectedContentType("".to_string()) + })?; + + // Normalize to lowercase for case-insensitive comparison + let content_type_normalized = content_type_str.to_ascii_lowercase(); + + match content_type_normalized.as_str() { + "text/event-stream" => Ok(ResponseType::EventStream), + "application/json" => Ok(ResponseType::Json), + other => Err(TransportError::UnexpectedContentType(other.to_string())), + } + } + None => Err(TransportError::UnexpectedContentType("".to_string())), + } +} /// Sends an HTTP POST request with the given body and headers /// @@ -17,21 +45,96 @@ pub async fn http_post( client: &Client, post_url: &str, body: String, - headers: &Option, -) -> TransportResult<()> { + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { let mut request = client .post(post_url) .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") .body(body); if let Some(map) = headers { request = request.headers(map.clone()); } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + let response = request.send().await?; if !response.status().is_success() { - return Err(TransportError::HttpError(response.status().as_u16())); + return Err(TransportError::Http(response.status())); } - Ok(()) + Ok(response) +} + +pub async fn http_get( + client: &Client, + url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .get(url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + return Err(TransportError::Http(response.status())); + } + Ok(response) +} + +pub async fn http_delete( + client: &Client, + post_url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .delete(post_url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status_code = response.status(); + return Err(TransportError::Http(status_code)); + } + Ok(response) +} + +#[allow(unused)] +pub fn get_header_value(response: &Response, header_name: HeaderName) -> Option { + let content_type = response.headers().get(header_name)?.to_str().ok()?; + Some(content_type.to_string()) } pub fn extract_origin(url: &str) -> Option { @@ -88,7 +191,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -113,11 +216,11 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is an HttpError with status 400 match result { - Err(TransportError::HttpError(status)) => assert_eq!(status, 400), + Err(TransportError::Http(status)) => assert_eq!(status, 400), _ => panic!("Expected HttpError with status 400"), } } @@ -142,7 +245,7 @@ mod tests { let headers = Some(create_test_headers()); // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -157,7 +260,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, url, body, &headers).await; + let result = http_post(&client, url, body, None, headers.as_ref()).await; // Assert the result is an error (likely a connection error) assert!(result.is_err()); diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs new file mode 100644 index 0000000..5933726 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -0,0 +1,320 @@ +use core::fmt; +use std::collections::HashMap; + +use bytes::{Bytes, BytesMut}; +const BUFFER_CAPACITY: usize = 1024; + +/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. +/// +/// Contains the event type, data payload, and optional event ID. +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl std::fmt::Display for SseEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(id) = &self.id { + writeln!(f, "id: {id}")?; + } + + if let Some(event) = &self.event { + writeln!(f, "event: {event}")?; + } + + if let Some(data) = &self.data { + match std::str::from_utf8(data) { + Ok(text) => { + for line in text.lines() { + writeln!(f, "data: {line}")?; + } + } + Err(_) => { + writeln!(f, "data: [binary data]")?; + } + } + } + + writeln!(f)?; // Trailing newline for SSE message end + Ok(()) + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self + .data + .as_ref() + .map(|b| String::from_utf8_lossy(b).to_string()); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +/// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. +/// This Parser is specifically designed for MCP messages and with no multi-line data support +/// +/// This struct maintains a buffer to accumulate incoming data and parses it into SSE events +/// based on the SSE protocol. It handles fields like `event`, `data`, and `id` as defined +/// in the SSE specification. +#[derive(Debug)] +pub struct SseParser { + pub buffer: BytesMut, +} + +impl SseParser { + /// Creates a new `SseParser` with an empty buffer pre-allocated to a default capacity. + /// + /// The buffer is initialized with a capacity of `BUFFER_CAPACITY` to + /// optimize for typical SSE message sizes. + /// + /// # Returns + /// A new `SseParser` instance with an empty buffer. + pub fn new() -> Self { + Self { + buffer: BytesMut::with_capacity(BUFFER_CAPACITY), + } + } + + /// Processes a new chunk of bytes and parses it into a vector of `SseEvent`s. + /// + /// This method appends the incoming `bytes` to the internal buffer, splits it into + /// complete lines (delimited by `\n`), and parses each line according to the SSE + /// protocol. It supports `event`, `id`, and `data` fields, as well as comments + /// (lines starting with `:`). Empty lines are skipped, and incomplete lines remain + /// in the buffer for future processing. + /// + /// # Parameters + /// - `bytes`: The incoming chunk of bytes to parse. + /// + /// # Returns + /// A vector of `SseEvent`s parsed from the complete lines in the buffer. If no + /// complete events are found, an empty vector is returned. + pub fn process_new_chunk(&mut self, bytes: Bytes) -> Vec { + self.buffer.extend_from_slice(&bytes); + + // Collect complete lines (ending in \n)β€”keep ALL lines, including empty ones for \n\n detection + let mut lines = Vec::new(); + while let Some(pos) = self.buffer.iter().position(|&b| b == b'\n') { + let line = self.buffer.split_to(pos + 1).freeze(); + lines.push(line); + } + + let mut events = Vec::new(); + let mut current_message_lines: Vec = Vec::new(); + + for line in lines { + current_message_lines.push(line); + + // Check if we've hit a double newline (end of message) + if current_message_lines.len() >= 2 + && current_message_lines + .last() + .is_some_and(|b| b.as_ref() == b"\n") + { + // Process the complete message (exclude the last empty lines for parsing) + let message_lines: Vec<_> = current_message_lines + .drain(..current_message_lines.len() - 1) + .filter(|l| l.as_ref() != b"\n") // Filter internal empties + .collect(); + + if let Some(event) = self.parse_sse_message(&message_lines) { + events.push(event); + } + } + } + + // Put back any incomplete message + if !current_message_lines.is_empty() { + self.buffer.clear(); + for line in current_message_lines { + self.buffer.extend_from_slice(&line); + } + } + + events + } + + fn parse_sse_message(&self, lines: &[Bytes]) -> Option { + let mut fields: HashMap = HashMap::new(); + let mut data_parts: Vec = Vec::new(); + + for line_bytes in lines { + let line_str = String::from_utf8_lossy(line_bytes); + + // Skip comments and empty lines + if line_str.is_empty() || line_str.starts_with(':') { + continue; + } + + let (key, value) = if let Some(value) = line_str.strip_prefix("data: ") { + ("data", value.trim_start().to_string()) + } else if let Some(value) = line_str.strip_prefix("event: ") { + ("event", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("id: ") { + ("id", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("retry: ") { + ("retry", value.trim().to_string()) + } else { + // Invalid line; skip + continue; + }; + + if key == "data" { + if !value.is_empty() { + data_parts.push(value); + } + } else { + fields.insert(key.to_string(), value); + } + } + + // Build data (concat multi-line data with \n) , should not occur in MCP tho + let data = if data_parts.is_empty() { + None + } else { + let full_data = data_parts.join("\n"); + Some(Bytes::copy_from_slice(full_data.as_bytes())) // Use copy_from_slice for efficiency + }; + + // Skip invalid message with no data + let data = data?; + + // Get event (default to None) + let event = fields.get("event").cloned(); + let id = fields.get("id").cloned(); + + Some(SseEvent { + event, + data: Some(data), + id, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[test] + fn test_single_data_event() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + assert!(events[0].event.is_none()); + assert!(events[0].id.is_none()); + } + + #[test] + fn test_event_with_id_and_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("event: message\nid: 123\ndata: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_event_chunks_in_different_orders() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\nevent: message\nid: 123\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_comment_line_ignored() { + let mut parser = SseParser::new(); + let input = Bytes::from(": this is a comment\n\n"); + let events = parser.process_new_chunk(input); + assert_eq!(events.len(), 0); + } + + #[test] + fn test_event_with_empty_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("data:\n\n"); + let events = parser.process_new_chunk(input); + // Your parser skips data lines with empty content + assert_eq!(events.len(), 0); + } + + #[test] + fn test_partial_chunks() { + let mut parser = SseParser::new(); + + let part1 = Bytes::from("data: hello"); + let part2 = Bytes::from(" world\n\n"); + + let events1 = parser.process_new_chunk(part1); + assert_eq!(events1.len(), 0); // incomplete + + let events2 = parser.process_new_chunk(part2); + assert_eq!(events2.len(), 1); + assert_eq!( + events2[0].data.as_deref(), + Some(Bytes::from("hello world\n").as_ref()) + ); + } + + #[test] + fn test_malformed_lines() { + let mut parser = SseParser::new(); + let input = Bytes::from("something invalid\ndata: ok\n\n"); + + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("ok\n").as_ref()) + ); + } + + #[test] + fn test_multiple_events_in_one_chunk() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: first\n\ndata: second\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 2); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("first\n").as_ref()) + ); + assert_eq!( + events[1].data.as_deref(), + Some(Bytes::from("second\n").as_ref()) + ); + } +} diff --git a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs new file mode 100644 index 0000000..3362c71 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs @@ -0,0 +1,374 @@ +use super::CancellationToken; +use crate::error::{TransportError, TransportResult}; +use crate::utils::SseParser; +use crate::utils::{http_get, validate_response_type, ResponseType}; +use crate::{utils::http_post, MCP_SESSION_ID_HEADER}; +use crate::{EventId, MCP_LAST_EVENT_ID_HEADER}; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::{Client, Response, StatusCode}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, RwLock}; +use tokio::time; +use tokio_stream::StreamExt; + +//-----------------------------------------------------------------------------------// +pub(crate) struct StreamableHttpStream { + /// HTTP client for making SSE requests + pub client: Client, + /// URL of the SSE endpoint + pub mcp_url: String, + /// Maximum number of retry attempts for failed connections + pub max_retries: usize, + /// Delay between retry attempts + pub retry_delay: Duration, + /// Sender for transmitting received data to the readable channel + pub read_tx: mpsc::Sender, + /// Session id will be received from the server in the http + pub session_id: Arc>>, +} + +impl StreamableHttpStream { + pub(crate) async fn run( + &mut self, + payload: String, + cancellation_token: &CancellationToken, + custom_headers: &Option, + ) -> TransportResult<()> { + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let session_id = self.session_id.read().await.clone(); + + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!( + "StreamableHttp cancelled before connection attempt {}", + payload + ); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + //TODO: simplify + let response = match http_post( + &self.client, + &self.mcp_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await + { + Ok(response) => { + // if session_id_clone.read().await.is_none() { + let session_id = response + .headers() + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let mut guard = self.session_id.write().await; + *guard = session_id; + response + } + + Err(error) => { + tracing::error!("Failed to connect to MCP endpoint: {error}"); + return Err(error); + } + }; + + // return if status code != 200 and no result is expected + if response.status() != StatusCode::OK { + return Ok(()); + } + + let response_type = validate_response_type(&response).await?; + + // Handle non-streaming JSON response + if response_type == ResponseType::Json { + return match response.bytes().await { + Ok(bytes) => { + // Send the message + self.read_tx.send(bytes).await.map_err(|_| { + tracing::error!("Readable stream closed, shutting down MCP task"); + TransportError::SendFailure( + "Failed to send message: channel closed or full".to_string(), + ) + })?; + + // Send the newline + self.read_tx + .send(Bytes::from_static(b"\n")) + .await + .map_err(|_| { + tracing::error!( + "Failed to send newline, channel may be closed or full" + ); + TransportError::SendFailure( + "Failed to send newline: channel closed or full".to_string(), + ) + })?; + + Ok(()) + } + Err(error) => Err(error.into()), + }; + } + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unlike SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + // break after receiving the message(s) + return Ok(()); + } + } + Err(error) => { + tracing::error!("Error reading stream: {error}"); + return Err(error.into()); + } + } + } + } + + pub(crate) async fn make_standalone_stream_connection( + &self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + last_event_id: Option, + ) -> TransportResult { + let mut retry_count = 0; + let session_id = self.session_id.read().await.clone(); + + let headers = if let Some(event_id) = last_event_id.as_ref() { + let mut headers = HeaderMap::new(); + if let Some(custom) = custom_headers { + headers.extend(custom.iter().map(|(k, v)| (k.clone(), v.clone()))); + } + if let Ok(event_id_value) = HeaderValue::from_str(event_id) { + headers.insert(MCP_LAST_EVENT_ID_HEADER, event_id_value); + } + &Some(headers) + } else { + custom_headers + }; + + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + match http_get( + &self.client, + &self.mcp_url, + session_id.as_ref(), + headers.as_ref(), + ) + .await + { + Ok(response) => { + let is_event_stream = validate_response_type(&response) + .await + .is_ok_and(|response_type| response_type == ResponseType::EventStream); + + if !is_event_stream { + let message = + "SSE stream response returned an unexpected Content-Type.".to_string(); + tracing::warn!("{message}"); + return Err(TransportError::FailedToOpenSSEStream(message)); + } + + return Ok(response); + } + + Err(error) => { + match error { + crate::error::TransportError::HttpConnection(_) => { + // A reqwest::Error happened, we do not return ans instead retry the operation + } + crate::error::TransportError::Http(status_code) => match status_code { + StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => { + return Err(crate::error::TransportError::FailedToOpenSSEStream( + format!("Not supported (code: {status_code})"), + )); + } + other => { + tracing::warn!( + "Failed to open SSE stream: {error} (code: {other})" + ); + } + }, + error => { + return Err(error); // return the error where the retry wont help + } + } + + if retry_count >= self.max_retries { + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error); + } + retry_count += 1; + time::sleep(self.retry_delay).await; + continue; + } + }; + } + } + + pub(crate) async fn run_standalone( + &mut self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + response: Response, + ) -> TransportResult<()> { + let mut retry_count = 0; + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let mut response = Some(response); + + // Main loop for reconnection attempts + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::debug!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + // use initially passed response, otherwise try to make a new sse connection + let response = match response.take() { + Some(response) => response, + None => { + tracing::debug!( + "Reconnecting to SSE stream... (try {} of {})", + retry_count, + self.max_retries + ); + self.make_standalone_stream_connection( + cancellation_token, + custom_headers, + _last_event_id.clone(), + ) + .await? + } + }; + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unlike SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + } + retry_count = 0; // Reset retry count on successful chunk + } + Err(error) => { + if retry_count >= self.max_retries { + tracing::error!("Error reading stream: {error}"); + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error.into()); + } + + tracing::debug!( + "The standalone SSE stream encountered an error: '{}'", + error + ); + retry_count += 1; + time::sleep(self.retry_delay).await; + break; // Break inner loop to reconnect + } + } + } + } + } +} diff --git a/crates/rust-mcp-transport/tests/check_imports.rs b/crates/rust-mcp-transport/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-transport/tests/check_imports.rs +++ b/crates/rust-mcp-transport/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/development.md b/development.md index e3673cc..e17dd17 100644 --- a/development.md +++ b/development.md @@ -33,14 +33,14 @@ Build and run instructions are available in their respective README.md files. You can run examples by passing the example project name to Cargo using the `-p` argument, like this: ```sh -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` -You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server`: +You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server-stdio`: ```sh -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` ## Code Formatting diff --git a/examples/hello-world-mcp-server-core/.gitignore b/examples/hello-world-mcp-server-stdio-core/.gitignore similarity index 100% rename from examples/hello-world-mcp-server-core/.gitignore rename to examples/hello-world-mcp-server-stdio-core/.gitignore diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-stdio-core/Cargo.toml similarity index 83% rename from examples/hello-world-mcp-server-core/Cargo.toml rename to examples/hello-world-mcp-server-stdio-core/Cargo.toml index bbab301..14eb904 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server-core" -version = "0.1.22" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server-core/README.md b/examples/hello-world-mcp-server-stdio-core/README.md similarity index 81% rename from examples/hello-world-mcp-server-core/README.md rename to examples/hello-world-mcp-server-stdio-core/README.md index af9d703..cf57884 100644 --- a/examples/hello-world-mcp-server-core/README.md +++ b/examples/hello-world-mcp-server-stdio-core/README.md @@ -23,14 +23,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server-core --release +cargo build -p hello-world-mcp-server-stdio-core --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-core` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio-core` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-core +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio-core ``` ``` @@ -41,4 +41,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-stdio-core/src/handler.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/handler.rs rename to examples/hello-world-mcp-server-stdio-core/src/handler.rs diff --git a/examples/hello-world-mcp-server-core/src/main.rs b/examples/hello-world-mcp-server-stdio-core/src/main.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/main.rs rename to examples/hello-world-mcp-server-stdio-core/src/main.rs diff --git a/examples/hello-world-mcp-server-core/src/tools.rs b/examples/hello-world-mcp-server-stdio-core/src/tools.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/tools.rs rename to examples/hello-world-mcp-server-stdio-core/src/tools.rs diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server-stdio/Cargo.toml similarity index 85% rename from examples/hello-world-mcp-server/Cargo.toml rename to examples/hello-world-mcp-server-stdio/Cargo.toml index 63a54af..9d15be3 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server" -version = "0.1.31" +name = "hello-world-mcp-server-stdio" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,8 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", - "hyper-server", - "ssl", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server/README.md b/examples/hello-world-mcp-server-stdio/README.md similarity index 84% rename from examples/hello-world-mcp-server/README.md rename to examples/hello-world-mcp-server-stdio/README.md index 33a62af..9e0bdda 100644 --- a/examples/hello-world-mcp-server/README.md +++ b/examples/hello-world-mcp-server-stdio/README.md @@ -22,14 +22,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio ``` ``` @@ -40,4 +40,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server/src/handler.rs b/examples/hello-world-mcp-server-stdio/src/handler.rs similarity index 100% rename from examples/hello-world-mcp-server/src/handler.rs rename to examples/hello-world-mcp-server-stdio/src/handler.rs diff --git a/examples/hello-world-mcp-server/src/main.rs b/examples/hello-world-mcp-server-stdio/src/main.rs similarity index 100% rename from examples/hello-world-mcp-server/src/main.rs rename to examples/hello-world-mcp-server-stdio/src/main.rs diff --git a/examples/hello-world-mcp-server/src/tools.rs b/examples/hello-world-mcp-server-stdio/src/tools.rs similarity index 100% rename from examples/hello-world-mcp-server/src/tools.rs rename to examples/hello-world-mcp-server-stdio/src/tools.rs diff --git a/examples/hello-world-server-core-streamable-http/.gitignore b/examples/hello-world-server-streamable-http-core/.gitignore similarity index 100% rename from examples/hello-world-server-core-streamable-http/.gitignore rename to examples/hello-world-server-streamable-http-core/.gitignore diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http-core/Cargo.toml similarity index 84% rename from examples/hello-world-server-core-streamable-http/Cargo.toml rename to examples/hello-world-server-streamable-http-core/Cargo.toml index 99d1011..a762058 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-server-core-streamable-http" -version = "0.1.22" +name = "hello-world-server-streamable-http-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-core-streamable-http/README.md b/examples/hello-world-server-streamable-http-core/README.md similarity index 95% rename from examples/hello-world-server-core-streamable-http/README.md rename to examples/hello-world-server-streamable-http-core/README.md index cd37623..49af2c2 100644 --- a/examples/hello-world-server-core-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http-core/README.md @@ -37,7 +37,7 @@ cd rust-mcp-sdk 2. Build and start the server: ```bash -cargo run -p hello-world-server-core-streamable-http --release +cargo run -p hello-world-server-streamable-http-core --release ``` By default, both the Streamable HTTP and SSE endpoints are displayed in the terminal: @@ -65,4 +65,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http-core/src/handler.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/handler.rs rename to examples/hello-world-server-streamable-http-core/src/handler.rs diff --git a/examples/hello-world-server-core-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/main.rs rename to examples/hello-world-server-streamable-http-core/src/main.rs diff --git a/examples/hello-world-server-core-streamable-http/src/tools.rs b/examples/hello-world-server-streamable-http-core/src/tools.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/tools.rs rename to examples/hello-world-server-streamable-http-core/src/tools.rs diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index df4296d..17a87c8 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-streamable-http/README.md b/examples/hello-world-server-streamable-http/README.md index ac56a86..7e3f3b6 100644 --- a/examples/hello-world-server-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http/README.md @@ -66,4 +66,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs index c4732d2..3939d86 100644 --- a/examples/hello-world-server-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http/src/handler.rs @@ -1,14 +1,11 @@ -use std::sync::Arc; - +use crate::tools::GreetingTools; use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; - -use crate::tools::GreetingTools; - +use std::sync::Arc; // Custom Handler to handle MCP Messages pub struct MyServerHandler; diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-sse-core/Cargo.toml similarity index 88% rename from examples/simple-mcp-client-core-sse/Cargo.toml rename to examples/simple-mcp-client-sse-core/Cargo.toml index 0e32790..25dcd7d 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core-sse" -version = "0.1.22" +name = "simple-mcp-client-sse-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "sse", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-core-sse/README.md b/examples/simple-mcp-client-sse-core/README.md similarity index 97% rename from examples/simple-mcp-client-core-sse/README.md rename to examples/simple-mcp-client-sse-core/README.md index e7e10d2..a0852fb 100644 --- a/examples/simple-mcp-client-core-sse/README.md +++ b/examples/simple-mcp-client-sse-core/README.md @@ -32,7 +32,7 @@ npx @modelcontextprotocol/server-everything sse 2. Open a new terminal and run the project with: ```bash -cargo run -p simple-mcp-client-core-sse +cargo run -p simple-mcp-client-sse-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-sse-core/src/handler.rs similarity index 100% rename from examples/simple-mcp-client-core-sse/src/handler.rs rename to examples/simple-mcp-client-sse-core/src/handler.rs diff --git a/examples/simple-mcp-client-core-sse/src/inquiry_utils.rs b/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core-sse/src/inquiry_utils.rs rename to examples/simple-mcp-client-sse-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core-sse/src/main.rs b/examples/simple-mcp-client-sse-core/src/main.rs similarity index 99% rename from examples/simple-mcp-client-core-sse/src/main.rs rename to examples/simple-mcp-client-sse-core/src/main.rs index 459f9ba..be8279b 100644 --- a/examples/simple-mcp-client-core-sse/src/main.rs +++ b/examples/simple-mcp-client-sse-core/src/main.rs @@ -44,6 +44,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime_core::create_client(client_details, transport, handler); // STEP 5: start the MCP client diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 14fd96b..bf7174d 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -9,6 +9,8 @@ license = "MIT" [dependencies] rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", + "sse", + "streamable-http", "macros", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-sse/src/main.rs b/examples/simple-mcp-client-sse/src/main.rs index ce8850a..0a76caa 100644 --- a/examples/simple-mcp-client-sse/src/main.rs +++ b/examples/simple-mcp-client-sse/src/main.rs @@ -15,7 +15,9 @@ use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; -const MCP_SERVER_URL: &str = "http://localhost:3001/sse"; +// Connect to a server started with the following command: +// npx @modelcontextprotocol/server-everything sse +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/sse"; #[tokio::main] async fn main() -> SdkResult<()> { @@ -44,6 +46,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime::create_client(client_details, transport, handler); // STEP 5: start the MCP client @@ -57,6 +60,7 @@ async fn main() -> SdkResult<()> { let utils = InquiryUtils { client: Arc::clone(&client), }; + // Display server information (name and version) utils.print_server_info(); @@ -78,8 +82,11 @@ async fn main() -> SdkResult<()> { // Call add tool, and print the result utils.call_add_tool(100, 25).await?; - // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + // // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } // Send 3 pings to the server, with a 2-second interval between each ping. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client-stdio-core/Cargo.toml similarity index 85% rename from examples/simple-mcp-client/Cargo.toml rename to examples/simple-mcp-client-stdio-core/Cargo.toml index 9599c46..6d95cf6 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client" -version = "0.1.31" +name = "simple-mcp-client-stdio-core" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } @@ -21,5 +22,6 @@ futures = { workspace = true } thiserror = { workspace = true } colored = "3.0.0" + [lints] workspace = true diff --git a/examples/simple-mcp-client-core/README.md b/examples/simple-mcp-client-stdio-core/README.md similarity index 97% rename from examples/simple-mcp-client-core/README.md rename to examples/simple-mcp-client-stdio-core/README.md index 52d8074..f3258aa 100644 --- a/examples/simple-mcp-client-core/README.md +++ b/examples/simple-mcp-client-stdio-core/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client-core +cargo run -p simple-mcp-client-stdio-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-stdio-core/src/handler.rs similarity index 100% rename from examples/simple-mcp-client-core/src/handler.rs rename to examples/simple-mcp-client-stdio-core/src/handler.rs diff --git a/examples/simple-mcp-client-core/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core/src/main.rs b/examples/simple-mcp-client-stdio-core/src/main.rs similarity index 100% rename from examples/simple-mcp-client-core/src/main.rs rename to examples/simple-mcp-client-stdio-core/src/main.rs diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-stdio/Cargo.toml similarity index 87% rename from examples/simple-mcp-client-core/Cargo.toml rename to examples/simple-mcp-client-stdio/Cargo.toml index 0dacc2d..3597105 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core" -version = "0.1.31" +name = "simple-mcp-client-stdio" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } @@ -21,6 +22,5 @@ futures = { workspace = true } thiserror = { workspace = true } colored = "3.0.0" - [lints] workspace = true diff --git a/examples/simple-mcp-client/README.md b/examples/simple-mcp-client-stdio/README.md similarity index 97% rename from examples/simple-mcp-client/README.md rename to examples/simple-mcp-client-stdio/README.md index c56a933..be17f02 100644 --- a/examples/simple-mcp-client/README.md +++ b/examples/simple-mcp-client-stdio/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client/src/handler.rs b/examples/simple-mcp-client-stdio/src/handler.rs similarity index 100% rename from examples/simple-mcp-client/src/handler.rs rename to examples/simple-mcp-client-stdio/src/handler.rs diff --git a/examples/simple-mcp-client/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client/src/main.rs b/examples/simple-mcp-client-stdio/src/main.rs similarity index 100% rename from examples/simple-mcp-client/src/main.rs rename to examples/simple-mcp-client-stdio/src/main.rs diff --git a/examples/simple-mcp-client-streamable-http-core/Cargo.toml b/examples/simple-mcp-client-streamable-http-core/Cargo.toml new file mode 100644 index 0000000..68356e1 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http-core" +version = "0.1.0" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "macros", + "streamable-http", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http-core/README.md b/examples/simple-mcp-client-streamable-http-core/README.md new file mode 100644 index 0000000..a0852fb --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client Core (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse-core +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http-core/src/handler.rs b/examples/simple-mcp-client-streamable-http-core/src/handler.rs new file mode 100644 index 0000000..ab86e9e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/handler.rs @@ -0,0 +1,72 @@ +use async_trait::async_trait; +use rust_mcp_sdk::schema::{ + self, + schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, + RpcError, ServerRequest, +}; +use rust_mcp_sdk::{mcp_client::ClientHandlerCore, McpClient}; +pub struct MyClientHandler; + +// To check out a list of all the methods in the trait that you can override, take a look at +// https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs + +#[async_trait] +impl ClientHandlerCore for MyClientHandler { + async fn handle_request( + &self, + request: RequestFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result { + match request { + RequestFromServer::ServerRequest(server_request) => match server_request { + ServerRequest::PingRequest(_) => { + return Ok(schema::Result::default().into()); + } + ServerRequest::CreateMessageRequest(_create_message_request) => { + Err(RpcError::internal_error().with_message( + "CreateMessageRequest handler is not implemented".to_string(), + )) + } + ServerRequest::ListRootsRequest(_list_roots_request) => { + Err(RpcError::internal_error() + .with_message("ListRootsRequest handler is not implemented".to_string())) + } + ServerRequest::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() + .with_message("ElicitRequest handler is not implemented".to_string())), + }, + RequestFromServer::CustomRequest(_value) => Err(RpcError::internal_error() + .with_message("CustomRequest handler is not implemented".to_string())), + } + } + + async fn handle_notification( + &self, + notification: NotificationFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) + } + + async fn handle_error( + &self, + _error: &RpcError, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/main.rs b/examples/simple-mcp-client-streamable-http-core/src/main.rs new file mode 100644 index 0000000..e1a5849 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/main.rs @@ -0,0 +1,95 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use inquiry_utils::InquiryUtils; +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime_core; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +// Assuming @modelcontextprotocol/server-everything is launched with streamableHttp argument and listening on port 3001 +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-core-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (Core,SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client + let client = + client_runtime_core::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + utils.client.set_logging_level(LoggingLevel::Debug).await?; + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +} diff --git a/examples/simple-mcp-client-streamable-http/Cargo.toml b/examples/simple-mcp-client-streamable-http/Cargo.toml new file mode 100644 index 0000000..0638aab --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http" +version = "0.1.0" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "streamable-http", + "macros", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http/README.md b/examples/simple-mcp-client-streamable-http/README.md new file mode 100644 index 0000000..5b4488e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http/src/handler.rs b/examples/simple-mcp-client-streamable-http/src/handler.rs new file mode 100644 index 0000000..19360f6 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/handler.rs @@ -0,0 +1,10 @@ +use async_trait::async_trait; +use rust_mcp_sdk::mcp_client::ClientHandler; + +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at + // https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} diff --git a/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http/src/main.rs b/examples/simple-mcp-client-streamable-http/src/main.rs new file mode 100644 index 0000000..ab580db --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/main.rs @@ -0,0 +1,99 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +use crate::inquiry_utils::InquiryUtils; + +const MCP_SERVER_URL: &str = "http://127.0.0.1:8080/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +} From 08742bb9636f81ee79eda4edc192b3b8ed4c7287 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:59:03 -0300 Subject: [PATCH 28/33] feat: event store support for resumability (#101) * Add Streamable HTTP Client and multiple refactoring and improvements * chore: typos * chore: update readme * feat: introduce event-store * chore: add event store to the app state * chore: refactor event store integration * chore: add tracing to inmemory store * chore: update examples to use event store * chore: improve flow * chore: replay mechanism * cleanup * test: add new test for event-store * chore: add tracing to tests * chore: add test * chore: refactor replaying logic * chore: cleanup * typo --- README.md | 9 +- crates/rust-mcp-sdk/README.md | 8 +- .../src/hyper_servers/app_state.rs | 4 + .../src/hyper_servers/routes/hyper_utils.rs | 125 ++++++-- .../routes/streamable_http_routes.rs | 9 +- .../rust-mcp-sdk/src/hyper_servers/server.rs | 8 +- .../src/mcp_runtimes/server_runtime.rs | 6 +- crates/rust-mcp-sdk/tests/common/common.rs | 114 ++++++-- .../rust-mcp-sdk/tests/common/test_server.rs | 4 + .../tests/test_streamable_http_client.rs | 1 + .../tests/test_streamable_http_server.rs | 218 ++++++++++++-- crates/rust-mcp-transport/src/client_sse.rs | 4 +- .../src/client_streamable_http.rs | 4 +- crates/rust-mcp-transport/src/event_store.rs | 27 ++ .../src/event_store/in_memory_event_store.rs | 274 ++++++++++++++++++ crates/rust-mcp-transport/src/lib.rs | 1 + .../src/message_dispatcher.rs | 83 +++++- crates/rust-mcp-transport/src/sse.rs | 43 ++- crates/rust-mcp-transport/src/stdio.rs | 8 +- crates/rust-mcp-transport/src/transport.rs | 2 +- crates/rust-mcp-transport/src/utils.rs | 2 + .../src/utils/time_utils.rs | 8 + .../src/main.rs | 4 + .../src/main.rs | 3 + 24 files changed, 863 insertions(+), 106 deletions(-) create mode 100644 crates/rust-mcp-transport/src/event_store.rs create mode 100644 crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs create mode 100644 crates/rust-mcp-transport/src/utils/time_utils.rs diff --git a/README.md b/README.md index c1e201c..51c3b49 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); @@ -191,7 +192,6 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { @@ -416,6 +416,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -432,6 +433,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -500,8 +505,8 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. - `sse`: Enables support for the `Server-Sent Events (SSE)` transport. - `streamable-http`: Enables support for the `Streamable HTTP` transport. -- `stdio`: Enables support for the `standard input/output (stdio)` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 8036022..51c3b49 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -153,6 +153,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); @@ -415,6 +416,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -431,6 +433,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -499,8 +505,8 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. - `sse`: Enables support for the `Server-Sent Events (SSE)` transport. - `streamable-http`: Enables support for the `Streamable HTTP` transport. -- `stdio`: Enables support for the `standard input/output (stdio)` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index ff6d5b2..e7f8793 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration}; use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; +use rust_mcp_transport::event_store::EventStore; use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server @@ -30,6 +31,9 @@ pub struct AppState { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, } impl AppState { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index da69c67..7101a73 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -23,7 +23,8 @@ use axum::{ use futures::stream; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ - SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, + EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; @@ -36,6 +37,7 @@ async fn create_sse_stream( state: Arc, payload: Option<&str>, standalone: bool, + last_event_id: Option, ) -> TransportServerResult> { let payload_string = payload.map(|p| p.to_string()); @@ -53,50 +55,85 @@ async fn create_sse_stream( // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - let transport = Arc::new( - SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?, - ); - - let stream_id: StreamId = if standalone { - DEFAULT_STREAM_ID.to_string() + let session_id = Arc::new(session_id); + let stream_id: Arc = if standalone { + Arc::new(DEFAULT_STREAM_ID.to_string()) } else { - state.stream_id_gen.generate() + Arc::new(state.stream_id_gen.generate()) }; + + let event_store = state.event_store.as_ref().map(Arc::clone); + let resumability_enabled = event_store.is_some(); + + let mut transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + if let Some(event_store) = event_store.clone() { + transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); + } + let transport = Arc::new(transport); + let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); + let stream_id_clone = stream_id.clone(); + let transport_clone = transport.clone(); //Start the server runtime tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream( + transport_clone, + &stream_id_clone, + ping_interval, + payload_string, + ) .await { - Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), } - let _ = runtime.remove_transport(&stream_id).await; + let _ = runtime.remove_transport(&stream_id_clone).await; }); // Construct SSE stream let reader = BufReader::new(write_rx); - // outgoing messages from server to the client - let message_stream = stream::unfold(reader, |mut reader| async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some((Ok(Event::default().data(trimmed_line)), reader)) + // send outgoing messages from server to the client over the sse stream + let message_stream = stream::unfold(reader, move |mut reader| { + async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + + // empty sse comment to keep-alive + if is_empty_sse_message(&trimmed_line) { + return Some((Ok(Event::default()), reader)); + } + + let (event_id, message) = match ( + resumability_enabled, + trimmed_line.split_once(char::from(ID_SEPARATOR)), + ) { + (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), + _ => (None, trimmed_line), + }; + + let event = match event_id { + Some(id) => Event::default().data(message).id(id), + None => Event::default().data(message), + }; + + Some((Ok(event), reader)) + } + Err(e) => Some((Err(e), reader)), } - Err(e) => Some((Err(e), reader)), } }); @@ -111,6 +148,23 @@ async fn create_sse_stream( HeaderValue::from_str(&session_id).unwrap(), ); + // if last_event_id exists we replay messages from the event-store + tokio::spawn(async move { + if let Some(last_event_id) = last_event_id { + if let Some(event_store) = state.event_store.as_ref() { + if let Some(events) = event_store.events_after(last_event_id).await { + for message_payload in events.messages { + // skip storing replay messages + let error = transport.write_str(&message_payload, true).await; + if let Err(error) = error { + tracing::trace!("Error replaying message: {error}") + } + } + } + } + } + }); + if !payload_contains_request { *response.status_mut() = StatusCode::ACCEPTED; } @@ -148,6 +202,7 @@ fn is_result(json_str: &str) -> Result { pub async fn create_standalone_stream( session_id: SessionId, + last_event_id: Option, state: Arc, ) -> TransportServerResult> { let runtime = state.session_store.get(&session_id).await.ok_or( @@ -161,12 +216,20 @@ pub async fn create_standalone_stream( return Ok((StatusCode::CONFLICT, Json(error)).into_response()); } + if let Some(last_event_id) = last_event_id.as_ref() { + tracing::trace!( + "SSE stream re-connected with last-event-id: {}", + last_event_id + ); + } + let mut response = create_sse_stream( runtime.clone(), session_id.clone(), state.clone(), None, true, + last_event_id, ) .await?; *response.status_mut() = StatusCode::OK; @@ -195,6 +258,7 @@ pub async fn start_new_session( state.clone(), Some(payload), false, + None, ) .await; @@ -354,6 +418,7 @@ pub async fn process_incoming_message( state.clone(), Some(payload), false, + None, ) .await } @@ -365,6 +430,10 @@ pub async fn process_incoming_message( } } +pub fn is_empty_sse_message(sse_payload: &str) -> bool { + sse_payload.is_empty() || sse_payload.trim() == ":" +} + pub async fn delete_session( session_id: SessionId, state: Arc, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 00d46c0..67f8679 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -23,7 +23,7 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, MCP_SESSION_ID_HEADER}; +use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use std::{collections::HashMap, sync::Arc}; pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { @@ -60,9 +60,14 @@ pub async fn handle_streamable_http_get( .and_then(|value| value.to_str().ok()) .map(|s| s.to_string()); + let last_event_id: Option = headers + .get(MCP_LAST_EVENT_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + match session_id { Some(session_id) => { - let res = create_standalone_stream(session_id, state).await?; + let res = create_standalone_stream(session_id, last_event_id, state).await?; Ok(res.into_response()) } None => { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 1c3b3cf..71bccee 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -23,7 +23,7 @@ use super::{ }; use crate::schema::InitializeResult; use axum::Router; -use rust_mcp_transport::{SessionId, TransportOptions}; +use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); @@ -53,6 +53,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -225,6 +229,7 @@ impl Default for HyperServerOptions { allowed_hosts: None, allowed_origins: None, dns_rebinding_protection: false, + event_store: None, } } } @@ -271,6 +276,7 @@ impl HyperServer { allowed_hosts: server_options.allowed_hosts.take(), allowed_origins: server_options.allowed_origins.take(), dns_rebinding_protection: server_options.dns_rebinding_protection, + event_store: server_options.event_store.as_ref().map(Arc::clone), }); let app = app_routes(Arc::clone(&state), &server_options); Self { diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 1b24b57..5502cee 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -368,16 +368,17 @@ impl ServerRuntime { Ok(()) } + //TODO: re-visit and simplify unnecessary hashmap pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { if stream_id != DEFAULT_STREAM_ID { return Ok(()); } - let mut transport_map = self.transport_map.write().await; + let transport_map = self.transport_map.read().await; tracing::trace!("removing transport for stream id : {}", stream_id); if let Some(transport) = transport_map.get(stream_id) { transport.shut_down().await?; } - transport_map.remove(stream_id); + // transport_map.remove(stream_id); Ok(()) } @@ -435,6 +436,7 @@ impl ServerRuntime { }; // in case there is a payload, we consume it by transport to get processed + // payload would be message payload coming from the client if let Some(payload) = payload { if let Err(err) = transport.consume_string_payload(&payload).await { let _ = self.remove_transport(stream_id).await; diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index f330dda..6b78895 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -11,9 +11,11 @@ use rust_mcp_sdk::mcp_client::ClientHandler; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; +use std::sync::Once; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::time::timeout; use tokio_stream::StreamExt; +use tracing_subscriber::EnvFilter; use wiremock::{MockServer, Request, ResponseTemplate}; pub use test_client::*; @@ -23,7 +25,17 @@ pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything #[cfg(unix)] pub const UVX_SERVER_GIT: &str = "mcp-server-git"; +static INIT: Once = Once::new(); +pub fn init_tracing() { + INIT.call_once(|| { + let filter = EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new("tracing")) + .unwrap(); + + tracing_subscriber::fmt().with_env_filter(filter).init(); + }); +} #[mcp_tool( name = "say_hello", description = "Accepts a person's name and says a personalized \"Hello\" to that person", @@ -126,16 +138,18 @@ pub async fn send_get_request( ); } } + client.get(url).headers(headers).send().await } use futures::stream::Stream; // stream: &mut impl Stream>, +/// reads sse events and return them as (id, event, data) tuple pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), event_count: usize, -) -> Option> { +) -> Option, Option, String)>> { let mut buffer = String::new(); let mut events = vec![]; @@ -146,27 +160,28 @@ pub async fn read_sse_event_from_stream( buffer.push_str(chunk_str); while let Some(pos) = buffer.find("\n\n") { - let data = { - // Scope to limit borrows - let (event_str, rest) = buffer.split_at(pos); - let mut data = None; - - // Process the event string - for line in event_str.lines() { - if line.starts_with("data:") { - data = Some(line.trim_start_matches("data:").trim().to_string()); - break; // Exit loop after finding data - } + let (event_str, rest) = buffer.split_at(pos); + let mut id = None; + let mut event = None; + let mut data = None; + + // Process the event string + for line in event_str.lines() { + if line.starts_with("id:") { + id = Some(line.trim_start_matches("id:").trim().to_string()); + } else if line.starts_with("event:") { + event = Some(line.trim_start_matches("event:").trim().to_string()); + } else if line.starts_with("data:") { + data = Some(line.trim_start_matches("data:").trim().to_string()); } + } - // Update buffer after processing - buffer = rest[2..].to_string(); // Skip "\n\n" - data - }; + // Update buffer after processing + buffer = rest[2..].to_string(); // Skip "\n\n" - // Return if data was found + // Only include events with data if let Some(data) = data { - events.push(data); + events.push((id, event, data)); if events.len().eq(&event_count) { return Some(events); } @@ -174,17 +189,26 @@ pub async fn read_sse_event_from_stream( } } Err(_e) => { - // return Err(TransportServerError::HyperError(e)); return None; } } } - None + if !events.is_empty() { + Some(events) + } else { + None + } } -pub async fn read_sse_event(response: Response, event_count: usize) -> Option> { +// return sse event as (id, event, data) tuple +pub async fn read_sse_event( + response: Response, + event_count: usize, +) -> Option, Option, String)>> { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream, event_count).await + let events = read_sse_event_from_stream(&mut stream, event_count).await; + // drop(stream); + events } pub fn test_client_info() -> InitializeRequestParams { @@ -280,9 +304,16 @@ pub fn random_port_old() -> u16 { } pub mod sample_tools { + use std::{sync::Arc, time::Duration}; + + use rust_mcp_schema::{LoggingMessageNotificationParams, TextContent}; #[cfg(feature = "2025_06_18")] use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; - use rust_mcp_sdk::schema::{schema_utils::CallToolError, CallToolResult}; + use rust_mcp_sdk::{ + schema::{schema_utils::CallToolError, CallToolResult}, + McpServer, + }; + use serde_json::json; //****************// // SayHelloTool // @@ -342,6 +373,43 @@ pub mod sample_tools { return Ok(CallToolResult::text_content(goodbye_message, None)); } } + + //****************************// + // StartNotificationStream // + //****************************// + #[mcp_tool( + name = "start-notification-stream", + description = "Accepts a person's name and says a personalized \"Goodbye\" to that person." + )] + #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] + pub struct StartNotificationStream { + /// Interval in milliseconds between notifications + interval: u64, + /// Number of notifications to send (0 for 100) + count: u32, + } + impl StartNotificationStream { + pub async fn call_tool( + &self, + runtime: Arc, + ) -> Result { + for i in 0..self.count { + let _ = runtime + .send_logging_message(LoggingMessageNotificationParams { + data: json!({"id":format!("message {} of {}",i,self.count)}), + level: rust_mcp_sdk::schema::LoggingLevel::Emergency, + logger: None, + }) + .await; + tokio::time::sleep(Duration::from_millis(self.interval)).await; + } + + let message = format!("so many messages sent"); + Ok(CallToolResult::text_content(vec![TextContent::from( + message, + )])) + } + } } pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request { diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index 769f8c6..d64244b 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -7,6 +7,7 @@ pub mod test_server_common { CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, RpcError, }; + use rust_mcp_sdk::event_store::EventStore; use rust_mcp_sdk::id_generator::IdGenerator; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; use rust_mcp_sdk::schema::{ @@ -31,6 +32,7 @@ pub mod test_server_common { pub streamable_url: String, pub sse_url: String, pub sse_message_url: String, + pub event_store: Option>, } pub fn initialize_request() -> InitializeRequest { @@ -120,6 +122,7 @@ pub mod test_server_common { let sse_url = options.sse_url(); let sse_message_url = options.sse_message_url(); + let event_store_clone = options.event_store.clone(); let server = hyper_server::create_server(test_server_details(), TestServerHandler {}, options); @@ -132,6 +135,7 @@ pub mod test_server_common { streamable_url, sse_url, sse_message_url, + event_store: event_store_clone, } } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs index cb82ff5..1d273e5 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -350,6 +350,7 @@ async fn should_receive_server_initiated_messaged() { streamable_url, sse_url, sse_message_url, + event_store, } = create_start_server(server_options).await; let (client, message_history) = create_client(&streamable_url, None).await; diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 4809d6d..af2dce6 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -12,7 +12,7 @@ use rust_mcp_schema::{ LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, }; -use rust_mcp_sdk::mcp_server::HyperServerOptions; +use rust_mcp_sdk::{event_store::InMemoryEventStore, mcp_server::HyperServerOptions}; use serde_json::{json, Map, Value}; use crate::common::{ @@ -40,6 +40,8 @@ async fn initialize_server( "AAA-BBB-CCC".to_string() ]))), enable_json_response, + ping_interval: Duration::from_secs(1), + event_store: Some(Arc::new(InMemoryEventStore::default())), ..Default::default() }; @@ -169,7 +171,7 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -220,7 +222,7 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -290,12 +292,20 @@ async fn should_reject_invalid_session_id() { server.hyper_runtime.await_server().await.unwrap() } -async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwest::Response { +async fn get_standalone_stream( + streamable_url: &str, + session_id: &str, + last_event_id: Option<&str>, +) -> reqwest::Response { let mut headers = HashMap::new(); headers.insert("Accept", "text/event-stream , application/json"); headers.insert("mcp-session-id", session_id); headers.insert("mcp-protocol-version", "2025-03-26"); + if let Some(last_event_id) = last_event_id.clone() { + headers.insert("last-event-id", last_event_id); + } + let response = send_get_request(streamable_url, Some(headers)) .await .unwrap(); @@ -306,7 +316,7 @@ async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwes #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_messages() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -345,7 +355,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .unwrap(); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -368,7 +378,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -429,14 +439,14 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { // read two events from the sse stream let events = read_sse_event(response, 2).await.unwrap(); - let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0]).unwrap(); + let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0].2).unwrap(); let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { panic!("invalid message received!"); }; - let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1]).unwrap(); + let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1].2).unwrap(); let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { @@ -453,7 +463,7 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { #[tokio::test] async fn should_not_close_get_sse_stream() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -472,7 +482,7 @@ async fn should_not_close_get_sse_stream() { let mut stream = response.bytes_stream(); let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -501,7 +511,7 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification_2, @@ -524,10 +534,10 @@ async fn should_not_close_get_sse_stream() { #[tokio::test] async fn should_reject_second_sse_stream_for_the_same_session() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); - let second_response = get_standalone_stream(&server.streamable_url, &session_id).await; + let second_response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(second_response.status(), StatusCode::CONFLICT); let error_data: SdkError = second_response.json().await.unwrap(); @@ -713,7 +723,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_2.status(), StatusCode::OK); let events = read_sse_event(response_2, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -729,7 +739,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() ); let events = read_sse_event(response_1, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -1080,7 +1090,7 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { ); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&events[0]).unwrap(); + let message: ServerMessages = serde_json::from_str(&events[0].2).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); @@ -1358,5 +1368,177 @@ async fn should_skip_all_validations_when_false() { server.hyper_runtime.await_server().await.unwrap() } -//TODO: +// should store and include event IDs in server SSE messages +#[tokio::test] +async fn should_store_and_include_event_ids_in_server_sse_messages() { + common::init_tracing(); + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification2"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + // read two events + let events = read_sse_event(response, 2).await.unwrap(); + assert_eq!(events.len(), 2); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + let (second_id, _, _) = events[0].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + assert!(second_id.is_some()); + + //messages should be stored and accessible + let events = server + .event_store + .unwrap() + .events_after(first_id) + .await + .unwrap(); + assert_eq!(events.messages.len(), 1); + + // deserialize the message returned by event_store + let message: ServerJsonrpcNotification = serde_json::from_str(&events.messages[0]).unwrap(); + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification2, + )) = message.notification + else { + panic!("invalid message in store!"); + }; + assert_eq!(notification2.params.data.as_str().unwrap(), "notification2"); +} + +// should store and replay MCP server tool notifications +#[tokio::test] +async fn should_store_and_replay_mcp_server_tool_notifications() { + common::init_tracing(); + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let events = read_sse_event(response, 1).await.unwrap(); + assert_eq!(events.len(), 1); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + + // sse connection is closed in read_sse_event() + // wait so server detect the disconnect and simulate a network error + tokio::time::sleep(Duration::from_secs(3)).await; + tokio::task::yield_now().await; + // we send another notification while SSE is disconnected + let _result = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification2"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + // make a new standalone SSE connection to simulate a re-connection + let response = + get_standalone_stream(&server.streamable_url, &session_id, Some(&first_id)).await; + assert_eq!(response.status(), StatusCode::OK); + let events = read_sse_event(response, 1).await.unwrap(); + + assert_eq!(events.len(), 1); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification2"); +} + // should return 400 error for invalid JSON-RPC messages +// should keep stream open after sending server notifications +// NA: should reject second initialization request +// NA: should pass request info to tool callback +// NA: should reject second SSE stream even in stateless mode +// should reject requests to uninitialized server +// should accept requests with matching protocol version +// should accept when protocol version differs from negotiated version +// should call a tool with authInfo +// should calls tool without authInfo when it is optional +// should accept pre-parsed request body +// should handle pre-parsed batch messages +// should prefer pre-parsed body over request body +// should operate without session ID validation +// should handle POST requests with various session IDs in stateless mode +// should call onsessionclosed callback when session is closed via DELETE +// should not call onsessionclosed callback when not provided +// should not call onsessionclosed callback for invalid session DELETE +// should call onsessionclosed callback with correct session ID when multiple sessions exist +// should support async onsessioninitialized callback +// should support sync onsessioninitialized callback (backwards compatibility) +// should support async onsessionclosed callback +// should propagate errors from async onsessioninitialized callback +// should propagate errors from async onsessionclosed callback +// should handle both async callbacks together +// should validate both host and origin when both are configured diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index 8d55bd0..0a1e8f3 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -457,10 +457,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/client_streamable_http.rs b/crates/rust-mcp-transport/src/client_streamable_http.rs index c318649..edda062 100644 --- a/crates/rust-mcp-transport/src/client_streamable_http.rs +++ b/crates/rust-mcp-transport/src/client_streamable_http.rs @@ -496,10 +496,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/event_store.rs b/crates/rust-mcp-transport/src/event_store.rs new file mode 100644 index 0000000..fdc0734 --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store.rs @@ -0,0 +1,27 @@ +mod in_memory_event_store; +use async_trait::async_trait; +pub use in_memory_event_store::*; + +use crate::{EventId, SessionId, StreamId}; + +#[derive(Debug, Clone)] +pub struct EventStoreMessages { + pub session_id: SessionId, + pub stream_id: StreamId, + pub messages: Vec, +} + +#[async_trait] +pub trait EventStore: Send + Sync { + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: String, + ) -> EventId; + async fn remove_by_session_id(&self, session_id: SessionId); + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId); + async fn clear(&self); + async fn events_after(&self, last_event_id: EventId) -> Option; +} diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs new file mode 100644 index 0000000..66e738c --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -0,0 +1,274 @@ +use async_trait::async_trait; +use std::collections::HashMap; +use std::collections::VecDeque; +use tokio::sync::RwLock; + +use crate::{ + event_store::{EventStore, EventStoreMessages}, + EventId, SessionId, StreamId, +}; + +const MAX_EVENTS_PER_SESSION: usize = 64; +const ID_SEPARATOR: &str = "-.-"; + +#[derive(Debug, Clone)] +struct EventEntry { + pub stream_id: StreamId, + pub time_stamp: u128, + pub message: String, +} + +#[derive(Debug)] +pub struct InMemoryEventStore { + max_events_per_session: usize, + storage_map: RwLock>>, +} + +impl Default for InMemoryEventStore { + fn default() -> Self { + Self { + max_events_per_session: MAX_EVENTS_PER_SESSION, + storage_map: Default::default(), + } + } +} + +/// In-memory implementation of the `EventStore` trait for MCP's Streamable HTTP transport. +/// +/// Stores events in a `HashMap` of session IDs to `VecDeque`s of events, with a per-session limit. +/// Events are identified by `event_id` (format: `session-.-stream-.-timestamp`) and used for SSE resumption. +/// Thread-safe via `RwLock` for concurrent access. +impl InMemoryEventStore { + /// Creates a new `InMemoryEventStore` with an optional maximum events per session. + /// + /// # Arguments + /// - `max_events_per_session`: Maximum number of events per session. Defaults to `MAX_EVENTS_PER_SESSION` (32) if `None`. + /// + /// # Returns + /// A new `InMemoryEventStore` instance with an empty `HashMap` wrapped in a `RwLock`. + /// + /// # Example + /// ``` + /// let store = InMemoryEventStore::new(Some(10)); + /// assert_eq!(store.max_events_per_session, 10); + /// ``` + pub fn new(max_events_per_session: Option) -> Self { + Self { + max_events_per_session: max_events_per_session.unwrap_or(MAX_EVENTS_PER_SESSION), + storage_map: RwLock::new(HashMap::new()), + } + } + + /// Generates an `event_id` string from session, stream, and timestamp components. + /// + /// Format: `session-.-stream-.-timestamp`, used as a resumption cursor in SSE (`Last-Event-ID`). + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// + /// # Returns + /// A `String` in the format `session-.-stream-.-timestamp`. + fn generate_event_id( + &self, + session_id: &SessionId, + stream_id: &StreamId, + time_stamp: u128, + ) -> String { + format!("{session_id}{ID_SEPARATOR}{stream_id}{ID_SEPARATOR}{time_stamp}") + } + + /// Parses an event ID into its session, stream, and timestamp components. + /// + /// The event ID must follow the format `session-.-stream-.-timestamp`. + /// Returns `None` if the format is invalid, empty, or contains invalid characters (e.g., NULL). + /// + /// # Arguments + /// - `event_id`: The event ID string to parse. + /// + /// # Returns + /// An `Option` containing a tuple of `(session_id, stream_id, time_stamp)` as string slices, + /// or `None` if the format is invalid. + /// + /// # Example + /// ``` + /// let store = InMemoryEventStore::new(None); + /// let event_id = "session1-.-stream1-.-12345"; + /// assert_eq!( + /// store.parse_event_id(event_id), + /// Some(("session1", "stream1", "12345")) + /// ); + /// assert_eq!(store.parse_event_id("invalid"), None); + /// ``` + pub fn parse_event_id<'a>(&self, event_id: &'a str) -> Option<(&'a str, &'a str, &'a str)> { + // Check for empty input or invalid characters (e.g., NULL) + if event_id.is_empty() || event_id.contains('\0') { + return None; + } + + // Split into exactly three parts + let parts: Vec<&'a str> = event_id.split(ID_SEPARATOR).collect(); + if parts.len() != 3 { + return None; + } + + let session_id = parts[0]; + let stream_id = parts[1]; + let time_stamp = parts[2]; + + // Ensure no part is empty + if session_id.is_empty() || stream_id.is_empty() || time_stamp.is_empty() { + return None; + } + + Some((session_id, stream_id, time_stamp)) + } +} + +#[async_trait] +impl EventStore for InMemoryEventStore { + /// Stores an event for a given session and stream, returning its `event_id`. + /// + /// Adds the event to the session’s `VecDeque`, removing the oldest event if the session + /// reaches `max_events_per_session`. + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// - `message`: The `ServerMessages` payload. + /// + /// # Returns + /// The generated `EventId` for the stored event. + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: String, + ) -> EventId { + let event_id = self.generate_event_id(&session_id, &stream_id, time_stamp); + + let mut storage_map = self.storage_map.write().await; + + tracing::trace!( + "Storing event for session: {session_id}, stream_id: {stream_id}, message: '{message}', {time_stamp} ", + ); + + let session_map = storage_map + .entry(session_id) + .or_insert_with(|| VecDeque::with_capacity(self.max_events_per_session)); + + if session_map.len() == self.max_events_per_session { + session_map.pop_front(); // remove the oldest if full + } + + let entry = EventEntry { + stream_id, + time_stamp, + message, + }; + + session_map.push_back(entry); + + event_id + } + + /// Removes all events associated with a given stream ID within a specific session. + /// + /// Removes events matching `stream_id` from the specified `session_id`’s event queue. + /// If the session’s queue becomes empty, it is removed from the store. + /// Idempotent if `session_id` or `stream_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to target. + /// - `stream_id`: The stream identifier to remove. + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId) { + let mut storage_map = self.storage_map.write().await; + + // Check if session exists + if let Some(events) = storage_map.get_mut(&session_id) { + // Remove events with the given stream_id + events.retain(|event| event.stream_id != stream_id); + // Remove session if empty + if events.is_empty() { + storage_map.remove(&session_id); + } + } + // No action if session_id doesn’t exist (idempotent) + } + + /// Removes all events associated with a given session ID. + /// + /// Removes the entire session from the store. Idempotent if `session_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to remove. + async fn remove_by_session_id(&self, session_id: SessionId) { + let mut storage_map = self.storage_map.write().await; + storage_map.remove(&session_id); + } + + /// Retrieves events after a given `event_id` for a specific session and stream. + /// + /// Parses `last_event_id` to extract `session_id`, `stream_id`, and `time_stamp`. + /// Returns events after the matching event in the session’s stream, sorted by timestamp + /// in ascending order (earliest to latest). Returns `None` if the `event_id` is invalid, + /// the session doesn’t exist, or the timestamp is non-numeric. + /// + /// # Arguments + /// - `last_event_id`: The event ID (format: `session-.-stream-.-timestamp`) to start after. + /// + /// # Returns + /// An `Option` containing `EventStoreMessages` with the session ID, stream ID, and sorted messages, + /// or `None` if no events are found or the input is invalid. + async fn events_after(&self, last_event_id: EventId) -> Option { + let Some((session_id, stream_id, time_stamp)) = self.parse_event_id(&last_event_id) else { + tracing::warn!("error parsing last event id: '{last_event_id}'"); + return None; + }; + + let storage_map = self.storage_map.read().await; + let Some(events) = storage_map.get(session_id) else { + tracing::warn!("could not find the session_id in the store : '{session_id}'"); + return None; + }; + + let Ok(time_stamp) = time_stamp.parse::() else { + tracing::warn!("could not parse the timestamp: '{time_stamp}'"); + return None; + }; + + let events = match events + .iter() + .position(|e| e.stream_id == stream_id && e.time_stamp == time_stamp) + { + Some(index) if index + 1 < events.len() => { + // Collect subsequent events that match the stream_id + let mut subsequent: Vec<_> = events + .range(index + 1..) + .filter(|e| e.stream_id == stream_id) + .cloned() + .collect(); + + subsequent.sort_by(|a, b| a.time_stamp.cmp(&b.time_stamp)); + subsequent.iter().map(|e| e.message.clone()).collect() + } + _ => vec![], + }; + + tracing::trace!("{} messages after '{last_event_id}'", events.len()); + + Some(EventStoreMessages { + session_id: session_id.to_string(), + stream_id: stream_id.to_string(), + messages: events, + }) + } + + async fn clear(&self) { + let mut storage_map = self.storage_map.write().await; + storage_map.clear(); + } +} diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 4a918db..d21e5dd 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -8,6 +8,7 @@ mod client_sse; mod client_streamable_http; mod constants; pub mod error; +pub mod event_store; mod mcp_stream; mod message_dispatcher; mod schema; diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 7c7c93e..cd9727c 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -1,13 +1,20 @@ -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, ServerMessages, +use crate::error::{TransportError, TransportResult}; +use crate::schema::{RequestId, RpcError}; +use crate::utils::{await_timeout, current_timestamp}; +use crate::McpDispatch; +use crate::{ + event_store::EventStore, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, + ServerMessages, + }, + JsonrpcError, }, - JsonrpcError, + SessionId, StreamId, }; -use crate::schema::{RequestId, RpcError}; use async_trait::async_trait; use futures::future::join_all; - use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -16,9 +23,7 @@ use tokio::io::AsyncWriteExt; use tokio::sync::oneshot::{self}; use tokio::sync::Mutex; -use crate::error::{TransportError, TransportResult}; -use crate::utils::await_timeout; -use crate::McpDispatch; +pub const ID_SEPARATOR: u8 = b'|'; /// Provides a dispatcher for sending MCP messages and handling responses. /// @@ -37,6 +42,10 @@ pub struct MessageDispatcher { )>, >, request_timeout: Duration, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } impl MessageDispatcher { @@ -60,6 +69,9 @@ impl MessageDispatcher { writable_std: Some(writable_std), writable_tx: None, request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } @@ -76,9 +88,25 @@ impl MessageDispatcher { writable_tx: Some(writable_tx), writable_std: None, request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. + pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); + } + async fn store_pending_request( &self, request_id: RequestId, @@ -141,7 +169,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; if let Some(rx) = rx_response { // Wait for the response with timeout @@ -177,7 +205,7 @@ impl McpDispatch let message_payload = serde_json::to_string(&client_messages).map_err(|_| { crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; // no request in the batch, no need to wait for the result if request_ids.is_empty() { @@ -233,7 +261,7 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, _skip_store: bool) -> TransportResult<()> { if let Some(writable_std) = self.writable_std.as_ref() { let mut writable_std = writable_std.lock().await; writable_std.write_all(payload.as_bytes()).await?; @@ -289,7 +317,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; if let Some(rx) = rx_response { match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { @@ -317,7 +345,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; // no request in the batch, no need to wait for the result if pending_tasks.is_empty() { @@ -375,9 +403,34 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { + let mut event_id = None; + + if !skip_store && !payload.trim().is_empty() { + if let (Some(session_id), Some(stream_id), Some(event_store)) = ( + self.session_id.as_ref(), + self.stream_id.as_ref(), + self.event_store.as_ref(), + ) { + event_id = Some( + event_store + .store_event( + session_id.clone(), + stream_id.clone(), + current_timestamp(), + payload.to_owned(), + ) + .await, + ) + }; + } + if let Some(writable_std) = self.writable_std.as_ref() { let mut writable_std = writable_std.lock().await; + if let Some(id) = event_id { + writable_std.write_all(id.as_bytes()).await?; + writable_std.write_all(&[ID_SEPARATOR]).await?; // separate id from message + } writable_std.write_all(payload.as_bytes()).await?; writable_std.write_all(b"\n").await?; // new line writable_std.flush().await?; diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 09809e4..89ca67f 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -1,3 +1,4 @@ +use crate::event_store::EventStore; use crate::schema::schema_utils::{ ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, }; @@ -19,7 +20,7 @@ use crate::mcp_stream::MCPStream; use crate::message_dispatcher::MessageDispatcher; use crate::transport::Transport; use crate::utils::{endpoint_with_session_id, CancellationTokenSource}; -use crate::{IoStream, McpDispatch, SessionId, TransportDispatcher, TransportOptions}; +use crate::{IoStream, McpDispatch, SessionId, StreamId, TransportDispatcher, TransportOptions}; pub struct SseTransport where @@ -33,6 +34,10 @@ where message_sender: Arc>>>, error_stream: tokio::sync::RwLock>, pending_requests: Arc>>>, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } /// Server-Sent Events (SSE) transport implementation @@ -67,6 +72,9 @@ where message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: None, + stream_id: None, + event_store: None, }) } @@ -86,6 +94,19 @@ where let mut lock = self.error_stream.write().await; *lock = Some(IoStream::Writable(error_stream)); } + + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. + pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); + } } #[async_trait] @@ -123,10 +144,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -161,7 +182,7 @@ impl Transport( + let (stream, mut sender, error_stream) = MCPStream::create::( Box::pin(read_rx), Mutex::new(Box::pin(write_tx)), IoStream::Writable(Box::pin(tokio::io::stderr())), @@ -170,6 +191,18 @@ impl Transport {} Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 11bd0a6..7678c65 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -348,10 +348,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -400,10 +400,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index b8e3ddc..a9e7190 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -82,7 +82,7 @@ where /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()>; + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()>; } /// A trait representing the transport layer for the MCP (Message Communication Protocol). diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 82d7326..034f062 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -25,6 +25,8 @@ pub(crate) use sse_stream::*; pub(crate) use streamable_http_stream::*; #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use writable_channel::*; +mod time_utils; +pub use time_utils::*; use crate::schema::schema_utils::SdkError; use tokio::time::{timeout, Duration}; diff --git a/crates/rust-mcp-transport/src/utils/time_utils.rs b/crates/rust-mcp-transport/src/utils/time_utils.rs new file mode 100644 index 0000000..25c4f5d --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/time_utils.rs @@ -0,0 +1,8 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn current_timestamp() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Invalid time") + .as_nanos() +} diff --git a/examples/hello-world-server-streamable-http-core/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs index 7b41c70..81a6ae5 100644 --- a/examples/hello-world-server-streamable-http-core/src/main.rs +++ b/examples/hello-world-server-streamable-http-core/src/main.rs @@ -1,7 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, @@ -48,6 +51,7 @@ async fn main() -> SdkResult<()> { handler, HyperServerOptions { sse_support: true, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); diff --git a/examples/hello-world-server-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http/src/main.rs index cd8c658..3923a6d 100644 --- a/examples/hello-world-server-streamable-http/src/main.rs +++ b/examples/hello-world-server-streamable-http/src/main.rs @@ -1,8 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; use std::time::Duration; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::mcp_server::{hyper_server, HyperServerOptions}; use handler::MyServerHandler; @@ -57,6 +59,7 @@ async fn main() -> SdkResult<()> { HyperServerOptions { host: "127.0.0.1".to_string(), ping_interval: Duration::from_secs(5), + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); From 3ab5fe73aaa10de2b5b23caee357ac15b37c845f Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Thu, 18 Sep 2025 19:23:31 -0300 Subject: [PATCH 29/33] feat: add elicitation macros and add elicit_input() method (#99) * Add Streamable HTTP Client and multiple refactoring and improvements * chore: typos * chore: update readme * feat: introduce event-store * chore: add event store to the app state * chore: refactor event store integration * chore: add tracing to inmemory store * chore: update examples to use event store * chore: improve flow * chore: replay mechanism * cleanup * test: add new test for event-store * chore: add tracing to tests * chore: add test * chore: refactor replaying logic * chore: cleanup * feat: add elicit_input to the McpServer * chore: enhance jsonschema macro * feat: introduce mcp_elicit macro * feat: add default, minimu, macimum support * improve tests and enum support * implement from_content_map * update docs * update readme * cleanup * fix: tests * update to latest rust-mcp-schema * chore: issues * chore: update readme * update readme * chore: typo --- Cargo.lock | 235 +++---- README.md | 118 +++- crates/rust-mcp-macros/README.md | 110 +++- crates/rust-mcp-macros/src/lib.rs | 609 ++++++++++++++++-- crates/rust-mcp-macros/src/utils.rs | 303 ++++++++- crates/rust-mcp-macros/tests/common/common.rs | 50 ++ crates/rust-mcp-macros/tests/macro_test.rs | 241 +++++++ crates/rust-mcp-sdk/README.md | 118 +++- .../src/hyper_servers/app_state.rs | 2 + .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 24 +- crates/rust-mcp-sdk/tests/common/common.rs | 2 +- .../tests/test_streamable_http_client.rs | 4 +- .../tests/test_streamable_http_server.rs | 2 +- .../hello-world-mcp-server-stdio/src/tools.rs | 29 +- 14 files changed, 1609 insertions(+), 238 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3c4462..6ee3950 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -84,9 +84,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.13.3" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c953fe1ba023e6b7730c0d4b031d06f267f23a46167dcbd40316644b10a17ba" +checksum = "94b8ff6c09cd57b16da53641caa860168b88c172a5ee163b0288d3d6eea12786" dependencies = [ "aws-lc-sys", "zeroize", @@ -94,9 +94,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" +checksum = "0e44d16778acaf6a9ec9899b92cebd65580b83f685446bf2e1f5d3d732f99dcd" dependencies = [ "bindgen", "cc", @@ -216,32 +216,29 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bindgen" -version = "0.69.5" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ "bitflags", "cexpr", "clang-sys", "itertools", - "lazy_static", - "lazycell", "log", "prettyplease", "proc-macro2", "quote", "regex", - "rustc-hash 1.1.0", + "rustc-hash", "shlex", "syn", - "which", ] [[package]] name = "bitflags" -version = "2.9.3" +version = "2.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" [[package]] name = "bumpalo" @@ -257,9 +254,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.35" +version = "1.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" +checksum = "65193589c6404eb80b450d618eaf9a2cafaaafd57ecce47370519ef674a7bd44" dependencies = [ "find-msvc-tools", "jobserver", @@ -427,16 +424,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" -[[package]] -name = "errno" -version = "0.3.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" -dependencies = [ - "libc", - "windows-sys 0.60.2", -] - [[package]] name = "event-listener" version = "2.5.3" @@ -454,9 +441,9 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" +checksum = "7fd99930f64d146689264c637b5af2f0233a933bef0d8570e2526bf9e083192d" [[package]] name = "fnv" @@ -475,9 +462,9 @@ dependencies = [ [[package]] name = "fs-err" -version = "3.1.1" +version = "3.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d7be93788013f265201256d58f04936a8079ad5dc898743aa20525f503b683" +checksum = "44f150ffc8782f35521cec2b23727707cb4045706ba3c854e86bef66b3a8cdbd" dependencies = [ "autocfg", "tokio", @@ -633,7 +620,7 @@ dependencies = [ "js-sys", "libc", "r-efi", - "wasi 0.14.3+wasi-0.2.4", + "wasi 0.14.7+wasi-0.2.4", "wasm-bindgen", ] @@ -753,15 +740,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" -[[package]] -name = "home" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" -dependencies = [ - "windows-sys 0.59.0", -] - [[package]] name = "http" version = "0.2.12" @@ -917,9 +895,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ "base64 0.22.1", "bytes", @@ -1048,9 +1026,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.11.0" +version = "2.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9" +checksum = "92119844f513ffa41556430369ab02c295a3578af21cf945caa3e9e0c2481ac3" dependencies = [ "equivalent", "hashbrown", @@ -1100,9 +1078,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -1125,9 +1103,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "852f13bec5eba4ba9afbeb93fd7c13fe56147f055939ae21c43a29a0ecb2702e" dependencies = [ "once_cell", "wasm-bindgen", @@ -1139,12 +1117,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "libc" version = "0.2.175" @@ -1161,12 +1133,6 @@ dependencies = [ "windows-targets 0.53.3", ] -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - [[package]] name = "litemap" version = "0.8.0" @@ -1191,9 +1157,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" [[package]] name = "lru-slab" @@ -1431,7 +1397,7 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", "socket2 0.6.0", "thiserror 2.0.16", @@ -1451,7 +1417,7 @@ dependencies = [ "lru-slab", "rand 0.9.2", "ring", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", "rustls-pki-types", "slab", @@ -1676,9 +1642,9 @@ dependencies = [ [[package]] name = "rust-mcp-schema" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "098436b06bfa4b88b110d12a5567cf37fd454735ee67cab7eb48bdbea0dd0e57" +checksum = "0bb65fd293dbbfabaacba1512b3948cdd9bf31ad1f2c0fed4962052b590c5c44" dependencies = [ "serde", "serde_json", @@ -1733,31 +1699,12 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.59.0", -] - [[package]] name = "rustls" version = "0.23.31" @@ -1794,9 +1741,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.4" +version = "0.103.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" +checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" dependencies = [ "aws-lc-rs", "ring", @@ -1824,18 +1771,28 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.225" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6c24dee235d0da097043389623fb913daddf92c76e9f5a1db88607a0bcbd1d" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "659356f9a0cb1e529b24c01e43ad2bdf520ec4ceaf83047b83ddcc2251f96383" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "0ea936adf78b1f766949a4977b91d2f5595825bd6ec079aa9543ad2685fc4516" dependencies = [ "proc-macro2", "quote", @@ -1844,24 +1801,26 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.143" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ "itoa", "memchr", "ryu", "serde", + "serde_core", ] [[package]] name = "serde_path_to_error" -version = "0.1.17" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" dependencies = [ "itoa", "serde", + "serde_core", ] [[package]] @@ -2129,9 +2088,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.42" +version = "0.3.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ca967379f9d8eb8058d86ed467d81d03e81acd45757e4ca341c24affbe8e8e3" +checksum = "83bde6f1ec10e72d583d91623c939f623002284ef622b87de38cfd546cbf2031" dependencies = [ "deranged", "num-conv", @@ -2143,15 +2102,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9108bb380861b07264b950ded55a44a14a4adc68b9f5efd85aafc3aa4d40a68" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.23" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7182799245a7264ce590b349d90338f1c1affad93d2639aed5f8f69c090b334c" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", @@ -2215,9 +2174,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +checksum = "05f63835928ca123f1bef57abbcd23bb2ba0ac9ae1235f1e65bda0d06e7786bd" dependencies = [ "rustls", "tokio", @@ -2369,9 +2328,9 @@ checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" [[package]] name = "untrusted" @@ -2399,9 +2358,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.18.0" +version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f33196643e165781c20a5ead5582283a7dacbb87855d867fbc2df3f81eddc1be" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ "getrandom 0.3.3", "js-sys", @@ -2449,30 +2408,40 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.3+wasi-0.2.4" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51ae83037bdd272a9e28ce236db8c07016dd0d50c27038b3f407533c030c95" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = "ab10a69fbd0a177f5f649ad4d8d3305499c42bab9aef2f7ff592d0ec8f833819" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", + "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +checksum = "0bb702423545a6007bbc368fde243ba47ca275e549c8a28617f56f6ba53b1d1c" dependencies = [ "bumpalo", "log", @@ -2484,9 +2453,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.50" +version = "0.4.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" +checksum = "a0b221ff421256839509adbb55998214a70d829d3a28c69b4a6672e9d2a42f67" dependencies = [ "cfg-if", "js-sys", @@ -2497,9 +2466,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "fc65f4f411d91494355917b605e1480033152658d71f722a90647f56a70c88a0" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2507,9 +2476,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "ffc003a991398a8ee604a401e194b6b3a39677b3173d6e74495eb51b82e99a32" dependencies = [ "proc-macro2", "quote", @@ -2520,9 +2489,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = "293c37f4efa430ca14db3721dfbe48d8c33308096bd44d80ebaa775ab71ba1cf" dependencies = [ "unicode-ident", ] @@ -2542,9 +2511,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "fbe734895e869dc429d78c4b433f8d17d95f8d05317440b4fad5ab2d33e596dc" dependencies = [ "js-sys", "wasm-bindgen", @@ -2569,18 +2538,6 @@ dependencies = [ "rustls-pki-types", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix", -] - [[package]] name = "windows-link" version = "0.1.3" @@ -2767,9 +2724,9 @@ dependencies = [ [[package]] name = "wit-bindgen" -version = "0.45.0" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052283831dbae3d879dc7f51f3d92703a316ca49f91540417d38591826127814" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "writeable" @@ -2803,18 +2760,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", diff --git a/README.md b/README.md index 51c3b49..2c70c3e 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,8 @@ This project supports following transports: - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response - βœ… Streamable HTTP Support for MCP Clients -- ⬜ Resumability -- ⬜ Authentication / Oauth +- βœ… Resumability +- ⬜ Oauth Authentication **⚠️** Project is currently under development and should be used at your own risk. @@ -50,6 +50,7 @@ This project supports following transports: - [MCP Client (stdio)](#mcp-client-stdio) - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) +- [Macros](#macros) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) - [Security Considerations](#security-considerations) @@ -386,6 +387,114 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. +## Macros +[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. + +> To use these macros, ensure the `macros` feature is enabled in your Cargo.toml. + +### mcp_tool +`mcp_tool` is a procedural macro attribute that helps generating rust_mcp_schema::Tool from a struct. + +Usage example: +```rust +#[mcp_tool( + name = "move_file", + title="Move File", + description = concat!("Move or rename files and directories. Can move files between directories ", +"and rename them in a single operation. If the destination exists, the ", +"operation will fail. Works across different directories and can be used ", +"for simple renaming within the same directory. ", +"Both source and destination must be within allowed directories."), + destructive_hint = false, + idempotent_hint = false, + open_world_hint = false, + read_only_hint = false +)] +#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] +pub struct MoveFileTool { + /// The source path of the file to move. + pub source: String, + /// The destination path to move the file to. + pub destination: String, +} + +// Now we can call `tool()` method on it to get a Tool instance +let rust_mcp_sdk::schema::Tool = MoveFileTool::tool(); + +``` + +πŸ’» For a real-world example, check out any of the tools available at: https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools + + +### tool_box +`tool_box` generates an enum from a provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. + +It accepts an array of tools and generates an enum where each tool becomes a variant of the enum. + +Generated enum has a `tools()` function that returns a `Vec` , and a `TryFrom` trait implementation that could be used to convert a ToolRequest into a Tool instance. + +Usage example: +```rust + // Accepts an array of tools and generates an enum named `FileSystemTools`, + // where each tool becomes a variant of the enum. + tool_box!(FileSystemTools, [ReadFileTool, MoveFileTool, SearchFilesTool]); + + // now in the app, we can use the FileSystemTools, like: + let all_tools: Vec = FileSystemTools::tools(); +``` + +πŸ’» To see a real-world example of that please see : +- `tool_box` macro usage: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs) +- using `tools()` in list tools request : [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L67) +- using `try_from` in call tool_request: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L100) + + + +### mcp_elicit +The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: + +- A `message()` method returning the elicitation message as a string. +- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. +- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. + +### Attributes + +- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. + +Usage example: +```rust +// A struct that could be used to send elicit request and get the input from the user +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} + +// send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance +let result: ElicitResult = server + .elicit_input(UserInfo::message(), UserInfo::requested_schema()) + .await?; + +// Create a UserInfo instance using data provided by the user on the client side +let user_info = UserInfo::from_content_map(result.content)?; + +``` + +πŸ’» For mre info please see : +- https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros + ## Getting Started If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) @@ -509,6 +618,7 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. + #### MCP Protocol Versions with Corresponding Features - `2025_06_18` : Activates MCP Protocol version 2025-06-18 (enabled by default) @@ -621,6 +731,10 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na + + + + ## Contributing We welcome everyone who wishes to contribute! Please refer to the [contributing](CONTRIBUTING.md) guidelines for more details. diff --git a/crates/rust-mcp-macros/README.md b/crates/rust-mcp-macros/README.md index 92da2c3..fc463cd 100644 --- a/crates/rust-mcp-macros/README.md +++ b/crates/rust-mcp-macros/README.md @@ -1,5 +1,8 @@ # rust-mcp-macros. + +## mcp_tool Macro + A procedural macro, part of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) ecosystem, to generate `rust_mcp_schema::Tool` instance from a struct. The `mcp_tool` macro generates an implementation for the annotated struct that includes: @@ -80,11 +83,7 @@ fn main() { ``` ---- - Check out [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) , a high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) takes care of the rest! - ---- **Note**: The following attributes are available only in version `2025_03_26` and later of the MCP Schema, and their values will be used in the [annotations](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5557) attribute of the *[Tool struct](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5554-L5566). @@ -93,3 +92,106 @@ fn main() { - `idempotent_hint` - `open_world_hint` - `read_only_hint` + + + + + +## mcp_elicit Macro + +The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: + +- A `message()` method returning the elicitation message as a string. +- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. +- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. + +### Attributes + +- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. + +### Supported Field Types + +- `String`: Maps to `ElicitResultContentValue::String`. +- `bool`: Maps to `ElicitResultContentValue::Boolean`. +- `i32`: Maps to `ElicitResultContentValue::Integer` (with bounds checking). +- `i64`: Maps to `ElicitResultContentValue::Integer`. +- `enum` Only simple enums are supported. The enum must implement the FromStr trait. +- `Option`: Supported for any of the above types, mapping to `None` if the field is missing. + + +### Usage Example + +```rust +use rust_mcp_sdk::macros::{mcp_elicit, JsonSchema}; +use rust_mcp_sdk::schema::RpcError; +use std::str::FromStr; + +// Simple enum with FromStr trait implemented +#[derive(JsonSchema, Debug)] +pub enum Colors { + #[json_schema(title = "Green Color")] + Green, + #[json_schema(title = "Red Color")] + Red, +} +impl FromStr for Colors { + type Err = RpcError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "green" => Ok(Colors::Green), + "red" => Ok(Colors::Red), + _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), + } + } +} + +// A struct that could be used to send elicit request and get the input from the user +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + + /// Email address of the user + #[json_schema(title = "Email", format = "email")] + pub email: Option, + + /// The user's age in years + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} + + // .... + // ....... + // ........... + + // send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance + + let result: ElicitResult = server + .elicit_input(UserInfo::message(), UserInfo::requested_schema()) + .await?; + + // Create a UserInfo instance using data provided by the user on the client side + let user_info = UserInfo::from_content_map(result.content)?; + + +``` + +--- + + Check out [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk), a high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) takes care of the rest! + +--- diff --git a/crates/rust-mcp-macros/src/lib.rs b/crates/rust-mcp-macros/src/lib.rs index 35d6e55..473792c 100644 --- a/crates/rust-mcp-macros/src/lib.rs +++ b/crates/rust-mcp-macros/src/lib.rs @@ -6,7 +6,7 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Error, Expr, - ExprLit, Fields, Lit, Meta, Token, + ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments, Token, Type, }; use utils::{is_option, renamed_field, type_to_json_schema}; @@ -45,6 +45,8 @@ struct McpToolMacroAttributes { use syn::parse::ParseStream; +use crate::utils::{generate_enum_parse, is_enum}; + struct ExprList { exprs: Punctuated, } @@ -246,6 +248,66 @@ impl Parse for McpToolMacroAttributes { } } +struct McpElicitationAttributes { + message: Option, +} + +impl Parse for McpElicitationAttributes { + fn parse(attributes: syn::parse::ParseStream) -> syn::Result { + let mut instance = Self { message: None }; + let meta_list: Punctuated = Punctuated::parse_terminated(attributes)?; + for meta in meta_list { + if let Meta::NameValue(meta_name_value) = meta { + let ident = meta_name_value.path.get_ident().unwrap(); + let ident_str = ident.to_string(); + if ident_str.as_str() == "message" { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + Expr::Macro(expr_macro) => { + let mac = &expr_macro.mac; + if mac.path.is_ident("concat") { + let args: ExprList = syn::parse2(mac.tokens.clone())?; + let mut result = String::new(); + for expr in args.exprs { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = expr + { + result.push_str(&lit_str.value()); + } else { + return Err(Error::new_spanned( + expr, + "Only string literals are allowed inside concat!()", + )); + } + } + result + } else { + return Err(Error::new_spanned( + expr_macro, + "Only concat!(...) is supported here", + )); + } + } + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a string literal or concat!(...)", + )); + } + }; + instance.message = Some(value) + } + } + } + Ok(instance) + } +} + /// A procedural macro attribute to generate rust_mcp_schema::Tool related utility methods for a struct. /// /// The `mcp_tool` macro generates an implementation for the annotated struct that includes: @@ -387,7 +449,7 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { let output = quote! { impl #input_ident { - /// Returns the name of the tool as a string. + /// Returns the name of the tool as a String. pub fn tool_name() -> String { #tool_name.to_string() } @@ -404,7 +466,7 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { .iter() .filter_map(|item| item.as_str().map(String::from)) .collect(), - None => Vec::new(), // Default to an empty vector if "required" is missing or not an array + None => Vec::new(), }; let properties: Option< @@ -440,6 +502,303 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { TokenStream::from(output) } +#[proc_macro_attribute] +pub fn mcp_elicit(attributes: TokenStream, input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let input_ident = &input.ident; + + // Conditionally select the path + let base_crate = if cfg!(feature = "sdk") { + quote! { rust_mcp_sdk::schema } + } else { + quote! { rust_mcp_schema } + }; + + let macro_attributes = parse_macro_input!(attributes as McpElicitationAttributes); + let message = macro_attributes.message.unwrap_or_default(); + + // Generate field assignments for from_content_map() + let field_assignments = match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => { + let assignments = fields.named.iter().map(|field| { + let field_attrs = &field.attrs; + let field_ident = &field.ident; + let renamed_field = renamed_field(field_attrs); + let field_name = renamed_field.unwrap_or_else(|| field_ident.as_ref().unwrap().to_string()); + let field_type = &field.ty; + + let type_check = if is_option(field_type) { + // Extract inner type for Option + let inner_type = match field_type { + Type::Path(type_path) => { + let segment = type_path.path.segments.last().unwrap(); + if segment.ident == "Option" { + match &segment.arguments { + PathArguments::AngleBracketed(args) => { + match args.args.first().unwrap() { + GenericArgument::Type(ty) => ty, + _ => panic!("Expected type argument in Option"), + } + } + _ => panic!("Invalid Option type"), + } + } else { + panic!("Expected Option type"); + } + } + _ => panic!("Expected Option type"), + }; + // Determine the match arm based on the inner type at compile time + let (inner_type_ident, match_pattern, conversion) = match inner_type { + Type::Path(type_path) if type_path.path.is_ident("String") => ( + quote! { String }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { s.clone() } + ), + Type::Path(type_path) if type_path.path.is_ident("bool") => ( + quote! { bool }, + quote! { #base_crate::ElicitResultContentValue::Boolean(b) }, + quote! { *b } + ), + Type::Path(type_path) if type_path.path.is_ident("i32") => ( + quote! { i32 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { + (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( + "Invalid number for field '{}': value {} does not fit in i32", + #field_name, *i + )))? + } + ), + Type::Path(type_path) if type_path.path.is_ident("i64") => ( + quote! { i64 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { *i } + ), + _ if is_enum(inner_type, &input) => { + let enum_parse = generate_enum_parse(inner_type, &field_name, &base_crate); + ( + quote! { #inner_type }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { #enum_parse } + ) + } + _ => panic!("Unsupported inner type for Option field: {}", quote! { #inner_type }), + }; + let inner_type_str = quote! { stringify!(#inner_type_ident) }; + quote! { + let #field_ident: Option<#inner_type_ident> = match content.as_ref().and_then(|map| map.get(#field_name)) { + Some(value) => { + match value { + #match_pattern => Some(#conversion), + _ => { + return Err(#base_crate::RpcError::parse_error().with_message(format!( + "Type mismatch for field '{}': expected {}, found {}", + #field_name, #inner_type_str, + match value { + #base_crate::ElicitResultContentValue::Boolean(_) => "boolean", + #base_crate::ElicitResultContentValue::String(_) => "string", + #base_crate::ElicitResultContentValue::Integer(_) => "integer", + } + ))); + } + } + } + None => None, + }; + } + } else { + // Determine the match arm based on the field type at compile time + let (field_type_ident, match_pattern, conversion) = match field_type { + Type::Path(type_path) if type_path.path.is_ident("String") => ( + quote! { String }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { s.clone() } + ), + Type::Path(type_path) if type_path.path.is_ident("bool") => ( + quote! { bool }, + quote! { #base_crate::ElicitResultContentValue::Boolean(b) }, + quote! { *b } + ), + Type::Path(type_path) if type_path.path.is_ident("i32") => ( + quote! { i32 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { + (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( + "Invalid number for field '{}': value {} does not fit in i32", + #field_name, *i + )))? + } + ), + Type::Path(type_path) if type_path.path.is_ident("i64") => ( + quote! { i64 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { *i } + ), + _ if is_enum(field_type, &input) => { + let enum_parse = generate_enum_parse(field_type, &field_name, &base_crate); + ( + quote! { #field_type }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { #enum_parse } + ) + } + _ => panic!("Unsupported field type: {}", quote! { #field_type }), + }; + let type_str = quote! { stringify!(#field_type_ident) }; + quote! { + let #field_ident: #field_type_ident = match content.as_ref().and_then(|map| map.get(#field_name)) { + Some(value) => { + match value { + #match_pattern => #conversion, + _ => { + return Err(#base_crate::RpcError::parse_error().with_message(format!( + "Type mismatch for field '{}': expected {}, found {}", + #field_name, #type_str, + match value { + #base_crate::ElicitResultContentValue::Boolean(_) => "boolean", + #base_crate::ElicitResultContentValue::String(_) => "string", + #base_crate::ElicitResultContentValue::Integer(_) => "integer", + } + ))); + } + } + } + None => { + return Err(#base_crate::RpcError::parse_error().with_message(format!( + "Missing required field: {}", + #field_name + ))); + } + }; + } + }; + + type_check + }); + + let field_idents = fields.named.iter().map(|field| &field.ident); + + quote! { + #(#assignments)* + + Ok(Self { + #(#field_idents,)* + }) + } + } + _ => panic!("mcp_elicit macro only supports structs with named fields"), + }, + _ => panic!("mcp_elicit macro only supports structs"), + }; + + let output = quote! { + impl #input_ident { + + /// Returns the elicitation message defined in the `#[mcp_elicit(message = "...")]` attribute. + /// + /// This message is used to prompt the user or system for input when eliciting data for the struct. + /// If no message is provided in the attribute, an empty string is returned. + /// + /// # Returns + /// A `String` containing the elicitation message. + pub fn message()->String{ + #message.to_string() + } + + /// This method returns a `ElicitRequestedSchema` by retrieves the + /// struct's JSON schema (via the `JsonSchema` derive) and converting int into + /// a `ElicitRequestedSchema`. It extracts the `required` fields and + /// `properties` from the schema, mapping them to a `HashMap` of `PrimitiveSchemaDefinition` objects. + /// + /// # Returns + /// An `ElicitRequestedSchema` representing the schema of the struct. + /// + /// # Panics + /// Panics if the schema's properties cannot be converted to `PrimitiveSchemaDefinition` or if the schema + /// is malformed. + pub fn requested_schema() -> #base_crate::ElicitRequestedSchema { + let json_schema = &#input_ident::json_schema(); + + let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) { + Some(arr) => arr + .iter() + .filter_map(|item| item.as_str().map(String::from)) + .collect(), + None => Vec::new(), + }; + + let properties: Option> = json_schema + .get("properties") + .and_then(|v| v.as_object()) // Safely extract "properties" as an object. + .map(|properties| { + properties + .iter() + .filter_map(|(key, value)| { + serde_json::to_value(value) + .ok() // If serialization fails, return None. + .and_then(|v| { + if let serde_json::Value::Object(obj) = v { + Some(obj) + } else { + None + } + }) + .map(|obj| (key.to_string(), #base_crate::PrimitiveSchemaDefinition::try_from(&obj))) + }) + .collect() + }); + + let properties = properties + .map(|map| { + map.into_iter() + .map(|(k, v)| v.map(|ok_v| (k, ok_v))) // flip Result inside tuple + .collect::, _>>() // collect only if all Ok + }) + .transpose() + .unwrap(); + + let properties = + properties.expect("Was not able to create a ElicitRequestedSchema"); + + let requested_schema = #base_crate::ElicitRequestedSchema::new(properties, required); + requested_schema + } + + /// Converts a map of field names and `ElicitResultContentValue` into an instance of the struct. + /// + /// This method parses the provided content map, matching field names to struct fields and converting + /// `ElicitResultContentValue` variants into the appropriate Rust types (e.g., `String`, `bool`, `i32`, + /// `i64`, or simple enums). It supports both required and optional fields (`Option`). + /// + /// # Parameters + /// - `content`: An optional `HashMap` mapping field names to `ElicitResultContentValue` values. + /// + /// # Returns + /// - `Ok(Self)` if the map is successfully parsed into the struct. + /// - `Err(RpcError)` if: + /// - A required field is missing. + /// - A value’s type does not match the expected field type. + /// - An integer value cannot be converted (e.g., `i64` to `i32` out of bounds). + /// - An enum value is invalid (e.g., string value does not match a enum variant name). + /// + /// # Errors + /// Returns `RpcError` with messages like: + /// - `"Missing required field: {}"` + /// - `"Type mismatch for field '{}': expected {}, found {}"` + /// - `"Invalid number for field '{}': value {} does not fit in i32"` + /// - `"Invalid enum value for field '{}': expected 'Yes' or 'No', found '{}'"`. + pub fn from_content_map(content: ::std::option::Option<::std::collections::HashMap<::std::string::String, #base_crate::ElicitResultContentValue>>) -> Result { + #field_assignments + } + } + #input + }; + + TokenStream::from(output) +} + /// Derives a JSON Schema representation for a struct. /// /// This procedural macro generates a `json_schema()` method for the annotated struct, returning a @@ -473,70 +832,222 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { /// # Dependencies /// Relies on `serde_json` for `Map` and `Value` types. /// -#[proc_macro_derive(JsonSchema)] +#[proc_macro_derive(JsonSchema, attributes(json_schema))] pub fn derive_json_schema(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as DeriveInput); let name = &input.ident; - let fields = match &input.data { + let schema_body = match &input.data { Data::Struct(data) => match &data.fields { - Fields::Named(fields) => &fields.named, - _ => panic!("JsonSchema derive macro only supports named fields"), + Fields::Named(fields) => { + let field_entries = fields.named.iter().map(|field| { + let field_attrs = &field.attrs; + let renamed_field = renamed_field(field_attrs); + let field_name = + renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + let field_type = &field.ty; + + let schema = type_to_json_schema(field_type, field_attrs); + quote! { + properties.insert( + #field_name.to_string(), + serde_json::Value::Object(#schema) + ); + } + }); + + let required_fields = fields.named.iter().filter_map(|field| { + let renamed_field = renamed_field(&field.attrs); + let field_name = + renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + + let field_type = &field.ty; + if !is_option(field_type) { + Some(quote! { + required.push(#field_name.to_string()); + }) + } else { + None + } + }); + + quote! { + let mut schema = serde_json::Map::new(); + let mut properties = serde_json::Map::new(); + let mut required = Vec::new(); + + #(#field_entries)* + + #(#required_fields)* + + schema.insert("type".to_string(), serde_json::Value::String("object".to_string())); + schema.insert("properties".to_string(), serde_json::Value::Object(properties)); + if !required.is_empty() { + schema.insert("required".to_string(), serde_json::Value::Array( + required.into_iter().map(serde_json::Value::String).collect() + )); + } + + schema + } + } + _ => panic!("JsonSchema derive macro only supports named fields for structs"), }, - _ => panic!("JsonSchema derive macro only supports structs"), - }; + Data::Enum(data) => { + let variant_schemas = data.variants.iter().map(|variant| { + let variant_attrs = &variant.attrs; + let variant_name = variant.ident.to_string(); + let renamed_variant = renamed_field(variant_attrs).unwrap_or(variant_name.clone()); - let field_entries = fields.iter().map(|field| { - let field_attrs = &field.attrs; - let renamed_field = renamed_field(field_attrs); - let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); - let field_type = &field.ty; + // Parse variant-level json_schema attributes + let mut title: Option = None; + let mut description: Option = None; + for attr in variant_attrs { + if attr.path().is_ident("json_schema") { + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("title") { + title = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("description") { + description = Some(meta.value()?.parse::()?.value()); + } + Ok(()) + }); + } + } - let schema = type_to_json_schema(field_type, field_attrs); - quote! { - properties.insert( - #field_name.to_string(), - serde_json::Value::Object(#schema) - ); - } - }); + let title_quote = title.as_ref().map(|t| { + quote! { map.insert("title".to_string(), serde_json::Value::String(#t.to_string())); } + }); + let description_quote = description.as_ref().map(|desc| { + quote! { map.insert("description".to_string(), serde_json::Value::String(#desc.to_string())); } + }); - let required_fields = fields.iter().filter_map(|field| { - let renamed_field = renamed_field(&field.attrs); - let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + match &variant.fields { + Fields::Unit => { + // Unit variant: use "enum" with the variant name + quote! { + { + let mut map = serde_json::Map::new(); + map.insert("enum".to_string(), serde_json::Value::Array(vec![ + serde_json::Value::String(#renamed_variant.to_string()) + ])); + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } + Fields::Unnamed(fields) => { + // Newtype or tuple variant + if fields.unnamed.len() == 1 { + // Newtype variant: use the inner type's schema + let field = &fields.unnamed[0]; + let field_type = &field.ty; + let field_attrs = &field.attrs; + let schema = type_to_json_schema(field_type, field_attrs); + quote! { + { + let mut map = #schema; + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } else { + // Tuple variant: array with items + let field_schemas = fields.unnamed.iter().map(|field| { + let field_type = &field.ty; + let field_attrs = &field.attrs; + let schema = type_to_json_schema(field_type, field_attrs); + quote! { serde_json::Value::Object(#schema) } + }); + quote! { + { + let mut map = serde_json::Map::new(); + map.insert("type".to_string(), serde_json::Value::String("array".to_string())); + map.insert("items".to_string(), serde_json::Value::Array(vec![#(#field_schemas),*])); + map.insert("additionalItems".to_string(), serde_json::Value::Bool(false)); + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } + } + Fields::Named(fields) => { + // Struct variant: object with properties and required fields + let field_entries = fields.named.iter().map(|field| { + let field_attrs = &field.attrs; + let renamed_field = renamed_field(field_attrs); + let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + let field_type = &field.ty; - let field_type = &field.ty; - if !is_option(field_type) { - Some(quote! { - required.push(#field_name.to_string()); - }) - } else { - None - } - }); + let schema = type_to_json_schema(field_type, field_attrs); + quote! { + properties.insert( + #field_name.to_string(), + serde_json::Value::Object(#schema) + ); + } + }); - let expanded = quote! { - impl #name { - pub fn json_schema() -> serde_json::Map { - let mut schema = serde_json::Map::new(); - let mut properties = serde_json::Map::new(); - let mut required = Vec::new(); + let required_fields = fields.named.iter().filter_map(|field| { + let renamed_field = renamed_field(&field.attrs); + let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + + let field_type = &field.ty; + if !is_option(field_type) { + Some(quote! { + required.push(#field_name.to_string()); + }) + } else { + None + } + }); - #(#field_entries)* + quote! { + { + let mut map = serde_json::Map::new(); + let mut properties = serde_json::Map::new(); + let mut required = Vec::new(); - #(#required_fields)* + #(#field_entries)* - schema.insert("type".to_string(), serde_json::Value::String("object".to_string())); - schema.insert("properties".to_string(), serde_json::Value::Object(properties)); - if !required.is_empty() { - schema.insert("required".to_string(), serde_json::Value::Array( - required.into_iter().map(serde_json::Value::String).collect() - )); + #(#required_fields)* + + map.insert("type".to_string(), serde_json::Value::String("object".to_string())); + map.insert("properties".to_string(), serde_json::Value::Object(properties)); + if !required.is_empty() { + map.insert("required".to_string(), serde_json::Value::Array( + required.into_iter().map(serde_json::Value::String).collect() + )); + } + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } } + }); + quote! { + let mut schema = serde_json::Map::new(); + schema.insert("oneOf".to_string(), serde_json::Value::Array(vec![ + #(#variant_schemas),* + ])); schema } } + _ => panic!("JsonSchema derive macro only supports structs and enums"), + }; + + let expanded = quote! { + impl #name { + pub fn json_schema() -> serde_json::Map { + #schema_body + } + } }; TokenStream::from(expanded) } diff --git a/crates/rust-mcp-macros/src/utils.rs b/crates/rust-mcp-macros/src/utils.rs index 0d4bbed..71d3de3 100644 --- a/crates/rust-mcp-macros/src/utils.rs +++ b/crates/rust-mcp-macros/src/utils.rs @@ -1,5 +1,8 @@ use quote::quote; -use syn::{punctuated::Punctuated, token, Attribute, Path, PathArguments, Type}; +use syn::{ + punctuated::Punctuated, token, Attribute, DeriveInput, Lit, LitInt, LitStr, Path, + PathArguments, Type, +}; // Check if a type is an Option pub fn is_option(ty: &Type) -> bool { @@ -13,8 +16,8 @@ pub fn is_option(ty: &Type) -> bool { false } -// Check if a type is a Vec #[allow(unused)] +// Check if a type is a Vec pub fn is_vec(ty: &Type) -> bool { if let Type::Path(type_path) = ty { if type_path.path.segments.len() == 1 { @@ -26,8 +29,8 @@ pub fn is_vec(ty: &Type) -> bool { false } -// Extract the inner type from Vec or Option #[allow(unused)] +// Extract the inner type from Vec or Option pub fn inner_type(ty: &Type) -> Option<&Type> { if let Type::Path(type_path) = ty { if type_path.path.segments.len() == 1 { @@ -46,12 +49,11 @@ pub fn inner_type(ty: &Type) -> Option<&Type> { None } -fn doc_comment(attrs: &[Attribute]) -> Option { +pub fn doc_comment(attrs: &[Attribute]) -> Option { let mut docs = Vec::new(); for attr in attrs { if attr.path().is_ident("doc") { if let syn::Meta::NameValue(meta) = &attr.meta { - // Match value as Expr::Lit, then extract Lit::Str if let syn::Expr::Lit(expr_lit) = &meta.value { if let syn::Lit::Str(lit_str) = &expr_lit.lit { docs.push(lit_str.value().trim().to_string()); @@ -82,16 +84,143 @@ pub fn might_be_struct(ty: &Type) -> bool { false } +// Helper to check if a type is an enum +pub fn is_enum(ty: &Type, _input: &DeriveInput) -> bool { + if let Type::Path(type_path) = ty { + // Check for #[mcp_elicit(enum)] attribute on the type + // Since we can't access the enum's definition directly, we rely on the attribute + // This assumes the enum is marked with #[mcp_elicit(enum)] in its definition + // Alternatively, we could pass a list of known enums, but attribute-based is simpler + type_path + .path + .segments + .last() + .map(|s| { + // For now, we'll assume any type could be an enum if it has the attribute + // In a real-world scenario, we'd need to resolve the type's definition + // For simplicity, we check if the type name is plausible (not String, bool, i32, i64) + let ident = s.ident.to_string(); + !["String", "bool", "i32", "i64"].contains(&ident.as_str()) + }) + .unwrap_or(false) + } else { + false + } +} + +// Helper to generate enum parsing code +pub fn generate_enum_parse( + field_type: &Type, + field_name: &str, + base_crate: &proc_macro2::TokenStream, +) -> proc_macro2::TokenStream { + let type_ident = match field_type { + Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(), + _ => panic!("Expected path type for enum"), + }; + // Since we can't access the enum's variants directly in this context, + // we'll assume the enum has unit variants and expect strings matching their names + // In a real-world scenario, you'd parse the enum's Data::Enum to get variant names + // For now, we'll generate a generic parse assuming variant names are provided as strings + quote! { + { + // Attempt to parse the string using a match + // Since we don't have the variants, we rely on the enum implementing FromStr + match s.as_str() { + // We can't dynamically list variants, so we use FromStr + // If FromStr is not implemented, this will fail at compile time + s => s.parse().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( + "Invalid enum value for field '{}': cannot parse '{}' into {}", + #field_name, s, stringify!(#type_ident) + )))? + } + } + } +} + pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::TokenStream { - let number_types = [ - "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64", + let integer_types = [ + "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", ]; - let doc_comment = doc_comment(attrs); - let description = doc_comment.as_ref().map(|desc| { + let float_types = ["f32", "f64"]; + + // Parse custom json_schema attributes + let mut title: Option = None; + let mut format: Option = None; + let mut min_length: Option = None; + let mut max_length: Option = None; + let mut minimum: Option = None; + let mut maximum: Option = None; + let mut default: Option = None; + let mut attr_description: Option = None; + + for attr in attrs { + if attr.path().is_ident("json_schema") { + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("title") { + title = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("description") { + attr_description = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("format") { + format = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("min_length") { + min_length = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("max_length") { + max_length = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("minimum") { + minimum = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("maximum") { + maximum = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("default") { + let lit = meta.value()?.parse::()?; + default = Some(match lit { + Lit::Str(lit_str) => { + let value = lit_str.value(); + quote! { serde_json::Value::String(#value.to_string()) } + } + Lit::Int(lit_int) => { + let value = lit_int.base10_parse::()?; + assert!( + (i64::MIN..=i64::MAX).contains(&value), + "Default value {value} out of range for i64" + ); + quote! { serde_json::Value::Number(serde_json::Number::from(#value)) } + } + Lit::Float(lit_float) => { + let value = lit_float.base10_parse::()?; + quote! { serde_json::Value::Number(serde_json::Number::from_f64(#value).expect("Invalid float")) } + } + Lit::Bool(lit_bool) => { + let value = lit_bool.value(); + quote! { serde_json::Value::Bool(#value) } + } + _ => return Err(meta.error("Unsupported default value type")), + }); + } + Ok(()) + }); + } + } + + let description = attr_description.or(doc_comment(attrs)); + let description_quote = description.as_ref().map(|desc| { quote! { map.insert("description".to_string(), serde_json::Value::String(#desc.to_string())); } }); + + let title_quote = title.as_ref().map(|t| { + quote! { + map.insert("title".to_string(), serde_json::Value::String(#t.to_string())); + } + }); + + let default_quote = default.as_ref().map(|d| { + quote! { + map.insert("default".to_string(), #d); + } + }); + match ty { Type::Path(type_path) => { if type_path.path.segments.len() == 1 { @@ -104,15 +233,43 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token if args.args.len() == 1 { if let syn::GenericArgument::Type(inner_ty) = &args.args[0] { let inner_schema = type_to_json_schema(inner_ty, attrs); + let format_quote = format.as_ref().map(|f| { + quote! { + map.insert("format".to_string(), serde_json::Value::String(#f.to_string())); + } + }); + let min_quote = min_length.as_ref().map(|min| { + quote! { + map.insert("minLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = max_length.as_ref().map(|max| { + quote! { + map.insert("maxLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); + let min_num_quote = minimum.as_ref().map(|min| { + quote! { + map.insert("minimum".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_num_quote = maximum.as_ref().map(|max| { + quote! { + map.insert("maximum".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { - let mut map = serde_json::Map::new(); - let inner_map = #inner_schema; - for (k, v) in inner_map { - map.insert(k, v); - } + let mut map = #inner_schema; map.insert("nullable".to_string(), serde_json::Value::Bool(true)); - #description + #description_quote + #title_quote + #format_quote + #min_quote + #max_quote + #min_num_quote + #max_num_quote + #default_quote map } }; @@ -126,12 +283,26 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token if args.args.len() == 1 { if let syn::GenericArgument::Type(inner_ty) = &args.args[0] { let inner_schema = type_to_json_schema(inner_ty, &[]); + let min_quote = min_length.as_ref().map(|min| { + quote! { + map.insert("minItems".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = max_length.as_ref().map(|max| { + quote! { + map.insert("maxItems".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("array".to_string())); map.insert("items".to_string(), serde_json::Value::Object(#inner_schema)); - #description + #description_quote + #title_quote + #min_quote + #max_quote + #default_quote map } }; @@ -144,36 +315,104 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token let path = &type_path.path; return quote! { { - let inner_schema = #path::json_schema(); - inner_schema + let mut map = #path::json_schema(); + #description_quote + #title_quote + #default_quote + map } }; } - // Handle basic types + // Handle String else if ident == "String" { + let format_quote = format.as_ref().map(|f| { + quote! { + map.insert("format".to_string(), serde_json::Value::String(#f.to_string())); + } + }); + let min_quote = min_length.as_ref().map(|min| { + quote! { + map.insert("minLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = max_length.as_ref().map(|max| { + quote! { + map.insert("maxLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("string".to_string())); - #description + #description_quote + #title_quote + #format_quote + #min_quote + #max_quote + #default_quote map } }; - } else if number_types.iter().any(|t| ident == t) { + } + // Handle integer types + else if integer_types.iter().any(|t| ident == t) { + let min_quote = minimum.as_ref().map(|min| { + quote! { + map.insert("minimum".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = maximum.as_ref().map(|max| { + quote! { + map.insert("maximum".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); + return quote! { + { + let mut map = serde_json::Map::new(); + map.insert("type".to_string(), serde_json::Value::String("integer".to_string())); + #description_quote + #title_quote + #min_quote + #max_quote + #default_quote + map + } + }; + } + // Handle float types + else if float_types.iter().any(|t| ident == t) { + let min_quote = minimum.as_ref().map(|min| { + quote! { + map.insert("minimum".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = maximum.as_ref().map(|max| { + quote! { + map.insert("maximum".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("number".to_string())); - #description + #description_quote + #title_quote + #min_quote + #max_quote + #default_quote map } }; - } else if ident == "bool" { + } + // Handle bool + else if ident == "bool" { return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("boolean".to_string())); - #description + #description_quote + #title_quote + #default_quote map } }; @@ -184,7 +423,9 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("unknown".to_string())); - #description + #description_quote + #title_quote + #default_quote map } } @@ -193,7 +434,9 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("unknown".to_string())); - #description + #description_quote + #title_quote + #default_quote map } }, @@ -204,7 +447,6 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token pub fn has_derive(attrs: &[Attribute], trait_name: &str) -> bool { attrs.iter().any(|attr| { if attr.path().is_ident("derive") { - // Parse the derive arguments as a comma-separated list of paths let parsed = attr.parse_args_with(Punctuated::::parse_terminated); if let Ok(derive_paths) = parsed { let derived = derive_paths.iter().any(|path| path.is_ident(trait_name)); @@ -220,7 +462,6 @@ pub fn renamed_field(attrs: &[Attribute]) -> Option { for attr in attrs { if attr.path().is_ident("serde") { - // Ignore other serde meta items (e.g., skip_serializing_if) let _ = attr.parse_nested_meta(|meta| { if meta.path.is_ident("rename") { if let Ok(lit) = meta.value() { @@ -493,12 +734,12 @@ mod tests { } #[test] - fn test_json_schema_number() { + fn test_json_schema_integer() { let ty: syn::Type = parse_quote!(i32); let tokens = type_to_json_schema(&ty, &[]); let output = render(tokens); assert!(output - .contains("\"type\".to_string(),serde_json::Value::String(\"number\".to_string())")); + .contains("\"type\".to_string(),serde_json::Value::String(\"integer\".to_string())")); } #[test] @@ -527,7 +768,7 @@ mod tests { let output = render(tokens); assert!(output.contains("\"nullable\".to_string(),serde_json::Value::Bool(true)")); assert!(output - .contains("\"type\".to_string(),serde_json::Value::String(\"number\".to_string())")); + .contains("\"type\".to_string(),serde_json::Value::String(\"integer\".to_string())")); } #[test] diff --git a/crates/rust-mcp-macros/tests/common/common.rs b/crates/rust-mcp-macros/tests/common/common.rs index 40c4e3c..d6bae2e 100644 --- a/crates/rust-mcp-macros/tests/common/common.rs +++ b/crates/rust-mcp-macros/tests/common/common.rs @@ -1,4 +1,7 @@ +use std::str::FromStr; + use rust_mcp_macros::JsonSchema; +use rust_mcp_schema::RpcError; #[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] /// Represents a text replacement operation. @@ -26,3 +29,50 @@ pub struct EditFileTool { )] pub dry_run: Option, } + +#[derive(JsonSchema, Debug)] +pub enum Colors { + #[json_schema(title = "Green Color")] + Green, + #[json_schema(title = "Red Color")] + Red, +} + +impl FromStr for Colors { + type Err = RpcError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "green" => Ok(Colors::Green), + "red" => Ok(Colors::Red), + _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), + } + } +} + +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + + /// Email address of the user + #[json_schema(title = "Email", format = "email")] + pub email: Option, + + /// The user's age in years + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} diff --git a/crates/rust-mcp-macros/tests/macro_test.rs b/crates/rust-mcp-macros/tests/macro_test.rs index 3a23c87..4b6c926 100644 --- a/crates/rust-mcp-macros/tests/macro_test.rs +++ b/crates/rust-mcp-macros/tests/macro_test.rs @@ -1,4 +1,16 @@ +#[macro_use] +extern crate rust_mcp_macros; + +use std::collections::HashMap; + use common::EditOperation; +use rust_mcp_schema::{ + BooleanSchema, ElicitRequestedSchema, ElicitResultContentValue, EnumSchema, NumberSchema, + PrimitiveSchemaDefinition, StringSchema, StringSchemaFormat, +}; +use serde_json::json; + +use crate::common::{Colors, UserInfo}; #[path = "common/common.rs"] pub mod common; @@ -31,3 +43,232 @@ fn test_rename() { let properties = schema.get("properties").unwrap().as_object().unwrap(); assert_eq!(properties.len(), 2); } + +#[test] +fn test_attributes() { + #[derive(JsonSchema)] + struct User { + /// This is a fallback description from doc comment. + pub id: i32, + + #[json_schema( + title = "User Name", + description = "The user's full name (overrides doc)", + min_length = 1, + max_length = 100 + )] + pub name: String, + + #[json_schema( + title = "User Email", + format = "email", + min_length = 5, + max_length = 255 + )] + pub email: Option, + + #[json_schema( + title = "Tags", + description = "List of tags", + min_length = 0, + max_length = 10 + )] + pub tags: Vec, + } + + let schema = User::json_schema(); + let expected = json!({ + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "This is a fallback description from doc comment." + }, + "name": { + "type": "string", + "title": "User Name", + "description": "The user's full name (overrides doc)", + "minLength": 1, + "maxLength": 100 + }, + "email": { + "type": "string", + "title": "User Email", + "format": "email", + "minLength": 5, + "maxLength": 255, + "nullable": true + }, + "tags": { + "type": "array", + "items": { + "type": "string", + }, + "title": "Tags", + "description": "List of tags", + "minItems": 0, + "maxItems": 10 + } + }, + "required": ["id", "name", "tags"] + }); + + // Convert expected_value from serde_json::Value to serde_json::Map + let expected: serde_json::Map = + expected.as_object().expect("Expected JSON object").clone(); + + assert_eq!(schema, expected); +} + +#[test] +fn test_elicit_macro() { + assert_eq!(UserInfo::message(), "Please enter your info"); + + let requested_schema: ElicitRequestedSchema = UserInfo::requested_schema(); + assert_eq!( + requested_schema.required, + vec!["name", "age", "favorate_color"] + ); + + assert!(matches!( + requested_schema.properties.get("is_student").unwrap(), + PrimitiveSchemaDefinition::BooleanSchema(BooleanSchema { + default, + description, + title, + .. + }) + if + description.as_ref().unwrap() == "Is user a student?" && + title.as_ref().unwrap() == "Is student?" && + matches!(default, Some(true)) + + )); + + assert!(matches!( + requested_schema.properties.get("favorate_color").unwrap(), + PrimitiveSchemaDefinition::EnumSchema(EnumSchema { + description, + enum_, + enum_names, + title, + .. + }) + if description.as_ref().unwrap() == "User's favorite color" && + title.is_none() && + enum_.len()==2 && enum_.iter().all(|s| ["Green", "Red"].contains(&s.as_str())) && + enum_names.len()==2 && enum_names.iter().all(|s| ["Green Color", "Red Color"].contains(&s.as_str())) + )); + + assert!(matches!( + requested_schema.properties.get("age").unwrap(), + PrimitiveSchemaDefinition::NumberSchema(NumberSchema { + description, + maximum, + minimum, + title, + type_ + }) + if + description.as_ref().unwrap() == "The user's age in years" && + maximum.unwrap() == 125 && minimum.unwrap() == 15 && title.as_ref().unwrap() == "Age" + )); + + assert!(matches!( + requested_schema.properties.get("name").unwrap(), + PrimitiveSchemaDefinition::StringSchema(StringSchema { + description, + format, + max_length, + min_length, + title, + .. + }) + if format.is_none() && + description.as_ref().unwrap() == "The user's full name" && + max_length.unwrap() == 100 && min_length.unwrap() == 5 && title.as_ref().unwrap() == "Name" + )); + + assert!(matches!( + requested_schema.properties.get("email").unwrap(), + PrimitiveSchemaDefinition::StringSchema(StringSchema { + description, + format, + max_length, + min_length, + title, + .. + }) if matches!(format.unwrap(), StringSchemaFormat::Email) && + description.as_ref().unwrap() == "Email address of the user" && + max_length.is_none() && min_length.is_none() && title.as_ref().unwrap() == "Email" + )); + + let json_schema = &UserInfo::json_schema(); + + let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) { + Some(arr) => arr + .iter() + .filter_map(|item| item.as_str().map(String::from)) + .collect(), + None => Vec::new(), + }; + + let properties: Option> = json_schema + .get("properties") + .and_then(|v| v.as_object()) // Safely extract "properties" as an object. + .map(|properties| { + properties + .iter() + .filter_map(|(key, value)| { + serde_json::to_value(value) + .ok() // If serialization fails, return None. + .and_then(|v| { + if let serde_json::Value::Object(obj) = v { + Some(obj) + } else { + None + } + }) + .map(|obj| (key.to_string(), PrimitiveSchemaDefinition::try_from(&obj))) + }) + .collect() + }); + + let properties = properties + .map(|map| { + map.into_iter() + .map(|(k, v)| v.map(|ok_v| (k, ok_v))) // flip Result inside tuple + .collect::, _>>() // collect only if all Ok + }) + .transpose() + .unwrap(); + + let properties = properties.expect("Was not able to create a ElicitRequestedSchema"); + + ElicitRequestedSchema::new(properties, required); +} + +#[test] +fn test_from_content_map() { + let mut content: ::std::collections::HashMap<::std::string::String, ElicitResultContentValue> = + HashMap::new(); + + content.extend([ + ( + "name".to_string(), + ElicitResultContentValue::String("Ali".to_string()), + ), + ( + "favorate_color".to_string(), + ElicitResultContentValue::String("Green".to_string()), + ), + ("age".to_string(), ElicitResultContentValue::Integer(15)), + ( + "is_student".to_string(), + ElicitResultContentValue::Boolean(false), + ), + ]); + + let u: UserInfo = UserInfo::from_content_map(Some(content)).unwrap(); + assert!(matches!(u.favorate_color, Colors::Green)); +} diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 51c3b49..2c70c3e 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -38,8 +38,8 @@ This project supports following transports: - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response - βœ… Streamable HTTP Support for MCP Clients -- ⬜ Resumability -- ⬜ Authentication / Oauth +- βœ… Resumability +- ⬜ Oauth Authentication **⚠️** Project is currently under development and should be used at your own risk. @@ -50,6 +50,7 @@ This project supports following transports: - [MCP Client (stdio)](#mcp-client-stdio) - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) +- [Macros](#macros) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) - [Security Considerations](#security-considerations) @@ -386,6 +387,114 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. +## Macros +[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. + +> To use these macros, ensure the `macros` feature is enabled in your Cargo.toml. + +### mcp_tool +`mcp_tool` is a procedural macro attribute that helps generating rust_mcp_schema::Tool from a struct. + +Usage example: +```rust +#[mcp_tool( + name = "move_file", + title="Move File", + description = concat!("Move or rename files and directories. Can move files between directories ", +"and rename them in a single operation. If the destination exists, the ", +"operation will fail. Works across different directories and can be used ", +"for simple renaming within the same directory. ", +"Both source and destination must be within allowed directories."), + destructive_hint = false, + idempotent_hint = false, + open_world_hint = false, + read_only_hint = false +)] +#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] +pub struct MoveFileTool { + /// The source path of the file to move. + pub source: String, + /// The destination path to move the file to. + pub destination: String, +} + +// Now we can call `tool()` method on it to get a Tool instance +let rust_mcp_sdk::schema::Tool = MoveFileTool::tool(); + +``` + +πŸ’» For a real-world example, check out any of the tools available at: https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools + + +### tool_box +`tool_box` generates an enum from a provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. + +It accepts an array of tools and generates an enum where each tool becomes a variant of the enum. + +Generated enum has a `tools()` function that returns a `Vec` , and a `TryFrom` trait implementation that could be used to convert a ToolRequest into a Tool instance. + +Usage example: +```rust + // Accepts an array of tools and generates an enum named `FileSystemTools`, + // where each tool becomes a variant of the enum. + tool_box!(FileSystemTools, [ReadFileTool, MoveFileTool, SearchFilesTool]); + + // now in the app, we can use the FileSystemTools, like: + let all_tools: Vec = FileSystemTools::tools(); +``` + +πŸ’» To see a real-world example of that please see : +- `tool_box` macro usage: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs) +- using `tools()` in list tools request : [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L67) +- using `try_from` in call tool_request: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L100) + + + +### mcp_elicit +The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: + +- A `message()` method returning the elicitation message as a string. +- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. +- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. + +### Attributes + +- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. + +Usage example: +```rust +// A struct that could be used to send elicit request and get the input from the user +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} + +// send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance +let result: ElicitResult = server + .elicit_input(UserInfo::message(), UserInfo::requested_schema()) + .await?; + +// Create a UserInfo instance using data provided by the user on the client side +let user_info = UserInfo::from_content_map(result.content)?; + +``` + +πŸ’» For mre info please see : +- https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros + ## Getting Started If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) @@ -509,6 +618,7 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. + #### MCP Protocol Versions with Corresponding Features - `2025_06_18` : Activates MCP Protocol version 2025-06-18 (enabled by default) @@ -621,6 +731,10 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na + + + + ## Contributing We welcome everyone who wishes to contribute! Please refer to the [contributing](CONTRIBUTING.md) guidelines for more details. diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index e7f8793..f96b261 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -3,7 +3,9 @@ use std::{sync::Arc, time::Duration}; use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; + use rust_mcp_transport::event_store::EventStore; + use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index dc860b6..da087d1 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -4,9 +4,10 @@ use crate::schema::{ ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, - GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult, - ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, - ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification, + ElicitRequest, ElicitRequestParams, ElicitRequestedSchema, ElicitResult, GetPromptRequest, + Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest, + ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams, + ListRootsResult, ListToolsRequest, LoggingMessageNotification, LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification, PromptListChangedNotificationParams, ReadResourceRequest, RequestId, ResourceListChangedNotification, ResourceListChangedNotificationParams, @@ -58,6 +59,23 @@ pub trait McpServer: Sync + Send { &self.server_info().capabilities } + /// Sends an elicitation request to the client to prompt user input and returns the received response. + /// + /// The requested_schema argument allows servers to define the structure of the expected response using a restricted subset of JSON Schema. + /// To simplify client user experience, elicitation schemas are limited to flat objects with primitive properties only + async fn elicit_input( + &self, + message: String, + requested_schema: ElicitRequestedSchema, + ) -> SdkResult { + let request: ElicitRequest = ElicitRequest::new(ElicitRequestParams { + message, + requested_schema, + }); + let response = self.request(request.into(), None).await?; + ElicitResult::try_from(response).map_err(|err| err.into()) + } + /// Sends a request to the client and processes the response. /// /// This function sends a `RequestFromServer` message to the client, waits for the response, diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 6b78895..d6b45f7 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -404,7 +404,7 @@ pub mod sample_tools { tokio::time::sleep(Duration::from_millis(self.interval)).await; } - let message = format!("so many messages sent"); + let message = "so many messages sent".to_string(); Ok(CallToolResult::text_content(vec![TextContent::from( message, )])) diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs index 1d273e5..ceb778a 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -309,7 +309,7 @@ async fn should_handle_successful_initial_get_connection_for_sse() { // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; // let mut body = String::new(); - body.push_str(&"data: Connection established\n\n".to_string()); + body.push_str("data: Connection established\n\n"); let response = ResponseTemplate::new(200) .set_body_raw(body.into_bytes(), "text/event-stream") @@ -428,7 +428,7 @@ async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; // let mut body = String::new(); - body.push_str(&"data: Connection established\n\n".to_string()); + body.push_str("data: Connection established\n\n"); let response = ResponseTemplate::new(405) .set_body_raw(body.into_bytes(), "text/event-stream") diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index af2dce6..79c9f00 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -302,7 +302,7 @@ async fn get_standalone_stream( headers.insert("mcp-session-id", session_id); headers.insert("mcp-protocol-version", "2025-03-26"); - if let Some(last_event_id) = last_event_id.clone() { + if let Some(last_event_id) = last_event_id { headers.insert("last-event-id", last_event_id); } diff --git a/examples/hello-world-mcp-server-stdio/src/tools.rs b/examples/hello-world-mcp-server-stdio/src/tools.rs index 15d6a8b..f6b1edb 100644 --- a/examples/hello-world-mcp-server-stdio/src/tools.rs +++ b/examples/hello-world-mcp-server-stdio/src/tools.rs @@ -1,8 +1,29 @@ use rust_mcp_sdk::schema::{schema_utils::CallToolError, CallToolResult, TextContent}; -use rust_mcp_sdk::{ - macros::{mcp_tool, JsonSchema}, - tool_box, -}; +use rust_mcp_sdk::{macros::mcp_tool, tool_box}; + +use rust_mcp_sdk::macros::JsonSchema; +use rust_mcp_sdk::schema::RpcError; +use std::str::FromStr; + +// Simple enum with FromStr trait implemented +#[derive(JsonSchema, Debug)] +pub enum Colors { + #[json_schema(title = "Green Color")] + Green, + #[json_schema(title = "Red Color")] + Red, +} +impl FromStr for Colors { + type Err = RpcError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "green" => Ok(Colors::Green), + "red" => Ok(Colors::Red), + _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), + } + } +} //****************// // SayHelloTool // From 3528fe2eeaecff14745073b418c0f7145187b9b9 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Thu, 18 Sep 2025 19:36:19 -0300 Subject: [PATCH 30/33] Update getting-started-mcp-server.md --- doc/getting-started-mcp-server.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 6fac258..9b5f61d 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -40,7 +40,7 @@ edition = "2024" [dependencies] async-trait = "0.1" -rust-mcp-sdk = "0.4" +rust-mcp-sdk = "0.7" serde = "1.0" serde_json = "1.0" tokio = "1.4" From d6a57b4f8d353a7a3adf1db0b82641c934b2dc9a Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Thu, 18 Sep 2025 19:46:11 -0300 Subject: [PATCH 31/33] Update getting-started-mcp-server.md --- doc/getting-started-mcp-server.md | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 9b5f61d..418fd66 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -72,11 +72,10 @@ Create a new module in the project called `tools.rs` and include the definitions //src/tools.rs use rust_mcp_sdk::schema::{CallToolResult, TextContent, schema_utils::CallToolError}; use rust_mcp_sdk::{ - macros::{mcp_tool, JsonSchema}, + macros::{JsonSchema, mcp_tool}, tool_box, }; - //****************// // SayHelloTool // //****************// @@ -93,7 +92,9 @@ pub struct SayHelloTool { impl SayHelloTool { pub fn call_tool(&self) -> Result { let hello_message = format!("Hello, {}!", self.name); - Ok(CallToolResult::text_content( vec![TextContent::from(hello_message)] )) + Ok(CallToolResult::text_content(vec![TextContent::from( + hello_message, + )])) } } @@ -112,7 +113,9 @@ pub struct SayGoodbyeTool { impl SayGoodbyeTool { pub fn call_tool(&self) -> Result { let hello_message = format!("Goodbye, {}!", self.name); - Ok(CallToolResult::text_content( vec![TextContent::from(hello_message)] )) + Ok(CallToolResult::text_content(vec![TextContent::from( + hello_message, + )])) } } @@ -142,12 +145,14 @@ Here is the code for `handler.rs` : ```rs // src/handler.rs +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ - schema_utils::CallToolError, CallToolRequest, CallToolResult, RpcError, - ListToolsRequest, ListToolsResult, + CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, RpcError, + schema_utils::CallToolError, }; -use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use rust_mcp_sdk::{McpServer, mcp_server::ServerHandler}; use crate::tools::GreetingTools; @@ -207,14 +212,11 @@ mod handler; mod tools; use handler::MyServerHandler; use rust_mcp_sdk::schema::{ - Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, - LATEST_PROTOCOL_VERSION, + Implementation, InitializeResult, LATEST_PROTOCOL_VERSION, ServerCapabilities, + ServerCapabilitiesTools, }; - use rust_mcp_sdk::{ - error::SdkResult, - mcp_server::{server_runtime, ServerRuntime}, - McpServer, StdioTransport, TransportOptions, + McpServer, StdioTransport, TransportOptions, error::SdkResult, mcp_server::server_runtime, }; #[tokio::main] @@ -244,7 +246,7 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; //create the MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); + let server = server_runtime::create_server(server_details, transport, handler); // Start the server server.start().await From 54b662cdccaf7a784be07bcba8eb7f07e99d95a5 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Thu, 18 Sep 2025 21:16:45 -0300 Subject: [PATCH 32/33] Update main.rs --- examples/simple-mcp-client-streamable-http/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/simple-mcp-client-streamable-http/src/main.rs b/examples/simple-mcp-client-streamable-http/src/main.rs index ab580db..95d4d8d 100644 --- a/examples/simple-mcp-client-streamable-http/src/main.rs +++ b/examples/simple-mcp-client-streamable-http/src/main.rs @@ -16,7 +16,7 @@ use tracing_subscriber::util::SubscriberInitExt; use crate::inquiry_utils::InquiryUtils; -const MCP_SERVER_URL: &str = "http://127.0.0.1:8080/mcp"; +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/mcp"; #[tokio::main] async fn main() -> SdkResult<()> { From 7083e543184b9d74ab4e9d8ac3b6f40ea8177085 Mon Sep 17 00:00:00 2001 From: Ali Hashemi <14126952+hashemix@users.noreply.github.com> Date: Fri, 19 Sep 2025 07:26:45 -0300 Subject: [PATCH 33/33] chore: release main (#100) * chore: release main * chore: update Cargo.toml for release --------- Co-authored-by: github-actions[bot] --- .release-manifest.json | 26 +++++++++---------- Cargo.lock | 26 +++++++++---------- Cargo.toml | 4 +-- crates/rust-mcp-macros/CHANGELOG.md | 7 +++++ crates/rust-mcp-macros/Cargo.toml | 2 +- crates/rust-mcp-sdk/CHANGELOG.md | 16 ++++++++++++ crates/rust-mcp-sdk/Cargo.toml | 2 +- crates/rust-mcp-transport/CHANGELOG.md | 19 ++++++++++++++ crates/rust-mcp-transport/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../hello-world-mcp-server-stdio/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../simple-mcp-client-sse-core/Cargo.toml | 2 +- examples/simple-mcp-client-sse/Cargo.toml | 2 +- .../simple-mcp-client-stdio-core/Cargo.toml | 2 +- examples/simple-mcp-client-stdio/Cargo.toml | 2 +- .../Cargo.toml | 2 +- .../Cargo.toml | 2 +- 19 files changed, 83 insertions(+), 41 deletions(-) diff --git a/.release-manifest.json b/.release-manifest.json index a645da6..db381e1 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,15 +1,15 @@ { - "crates/rust-mcp-sdk": "0.6.3", - "crates/rust-mcp-macros": "0.5.1", - "crates/rust-mcp-transport": "0.5.0", - "examples/hello-world-mcp-server-stdio": "0.1.28", - "examples/hello-world-mcp-server-stdio-core": "0.1.19", - "examples/simple-mcp-client-stdio": "0.1.28", - "examples/simple-mcp-client-stdio-core": "0.1.28", - "examples/hello-world-server-streamable-http-core": "0.1.19", - "examples/hello-world-server-streamable-http": "0.1.28", - "examples/simple-mcp-client-sse-core": "0.1.19", - "examples/simple-mcp-client-sse": "0.1.19", - "examples/simple-mcp-client-streamable-http": "0.1.0", - "examples/simple-mcp-client-streamable-http-core": "0.1.0" + "crates/rust-mcp-sdk": "0.7.0", + "crates/rust-mcp-macros": "0.5.2", + "crates/rust-mcp-transport": "0.6.0", + "examples/hello-world-mcp-server-stdio": "0.1.29", + "examples/hello-world-mcp-server-stdio-core": "0.1.20", + "examples/simple-mcp-client-stdio": "0.1.29", + "examples/simple-mcp-client-stdio-core": "0.1.29", + "examples/hello-world-server-streamable-http-core": "0.1.20", + "examples/hello-world-server-streamable-http": "0.1.32", + "examples/simple-mcp-client-sse-core": "0.1.20", + "examples/simple-mcp-client-sse": "0.1.23", + "examples/simple-mcp-client-streamable-http": "0.1.1", + "examples/simple-mcp-client-streamable-http-core": "0.1.1" } diff --git a/Cargo.lock b/Cargo.lock index 6ee3950..0acb30d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -682,7 +682,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "hello-world-mcp-server-stdio" -version = "0.1.28" +version = "0.1.29" dependencies = [ "async-trait", "futures", @@ -696,7 +696,7 @@ dependencies = [ [[package]] name = "hello-world-mcp-server-stdio-core" -version = "0.1.19" +version = "0.1.20" dependencies = [ "async-trait", "futures", @@ -708,7 +708,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http" -version = "0.1.31" +version = "0.1.32" dependencies = [ "async-trait", "futures", @@ -722,7 +722,7 @@ dependencies = [ [[package]] name = "hello-world-server-streamable-http-core" -version = "0.1.19" +version = "0.1.20" dependencies = [ "async-trait", "futures", @@ -1630,7 +1630,7 @@ dependencies = [ [[package]] name = "rust-mcp-macros" -version = "0.5.1" +version = "0.5.2" dependencies = [ "proc-macro2", "quote", @@ -1652,7 +1652,7 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.6.3" +version = "0.7.0" dependencies = [ "async-trait", "axum", @@ -1677,7 +1677,7 @@ dependencies = [ [[package]] name = "rust-mcp-transport" -version = "0.5.1" +version = "0.6.0" dependencies = [ "async-trait", "bytes", @@ -1872,7 +1872,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse" -version = "0.1.22" +version = "0.1.23" dependencies = [ "async-trait", "colored", @@ -1888,7 +1888,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-sse-core" -version = "0.1.19" +version = "0.1.20" dependencies = [ "async-trait", "colored", @@ -1904,7 +1904,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-stdio" -version = "0.1.28" +version = "0.1.29" dependencies = [ "async-trait", "colored", @@ -1918,7 +1918,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-stdio-core" -version = "0.1.28" +version = "0.1.29" dependencies = [ "async-trait", "colored", @@ -1932,7 +1932,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-streamable-http" -version = "0.1.0" +version = "0.1.1" dependencies = [ "async-trait", "colored", @@ -1948,7 +1948,7 @@ dependencies = [ [[package]] name = "simple-mcp-client-streamable-http-core" -version = "0.1.0" +version = "0.1.1" dependencies = [ "async-trait", "colored", diff --git a/Cargo.toml b/Cargo.toml index 711204d..edb7e28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,9 +19,9 @@ members = [ [workspace.dependencies] # Workspace member crates -rust-mcp-transport = { version = "0.5.1", path = "crates/rust-mcp-transport", default-features = false } +rust-mcp-transport = { version = "0.6.0", path = "crates/rust-mcp-transport", default-features = false } rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false } -rust-mcp-macros = { version = "0.5.1", path = "crates/rust-mcp-macros", default-features = false } +rust-mcp-macros = { version = "0.5.2", path = "crates/rust-mcp-macros", default-features = false } # External crates rust-mcp-schema = { version = "0.7", default-features = false } diff --git a/crates/rust-mcp-macros/CHANGELOG.md b/crates/rust-mcp-macros/CHANGELOG.md index a7b5306..69b3059 100644 --- a/crates/rust-mcp-macros/CHANGELOG.md +++ b/crates/rust-mcp-macros/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.5.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-macros-v0.5.1...rust-mcp-macros-v0.5.2) (2025-09-19) + + +### πŸš€ Features + +* Add elicitation macros and add elicit_input() method ([#99](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/99)) ([3ab5fe7](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/3ab5fe73aaa10de2b5b23caee357ac15b37c845f)) + ## [0.5.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-macros-v0.5.0...rust-mcp-macros-v0.5.1) (2025-08-12) diff --git a/crates/rust-mcp-macros/Cargo.toml b/crates/rust-mcp-macros/Cargo.toml index 0dfdc56..9c2dd5a 100644 --- a/crates/rust-mcp-macros/Cargo.toml +++ b/crates/rust-mcp-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-macros" -version = "0.5.1" +version = "0.5.2" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "A procedural macro that derives the MCPToolSchema implementation for structs or enums, generating a tool_input_schema function used with rust_mcp_schema::Tool." diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index db5a72b..4fde908 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## [0.7.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.3...rust-mcp-sdk-v0.7.0) (2025-09-19) + + +### ⚠ BREAKING CHANGES + +* add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) +* update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) + +### πŸš€ Features + +* Add elicitation macros and add elicit_input() method ([#99](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/99)) ([3ab5fe7](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/3ab5fe73aaa10de2b5b23caee357ac15b37c845f)) +* Add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) ([abb0c36](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/abb0c36126b0a397bc20a1de36c5a5a80924a01e)) +* Add tls-no-provider feature ([#97](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/97)) ([5dacceb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/5dacceb0c2d18b8334744a13d438c6916bb7244c)) +* Event store support for resumability ([#101](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/101)) ([08742bb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/08742bb9636f81ee79eda4edc192b3b8ed4c7287)) +* Update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) ([a2d6d23](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/a2d6d23ab59fbc34d04526e2606f747f93a8468c)) + ## [0.6.3](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.2...rust-mcp-sdk-v0.6.3) (2025-08-31) ## [0.6.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.1...rust-mcp-sdk-v0.6.2) (2025-08-30) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 99d6f86..8bba7c7 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.6.3" +version = "0.7.0" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." diff --git a/crates/rust-mcp-transport/CHANGELOG.md b/crates/rust-mcp-transport/CHANGELOG.md index 9a0d2e1..2d692b4 100644 --- a/crates/rust-mcp-transport/CHANGELOG.md +++ b/crates/rust-mcp-transport/CHANGELOG.md @@ -1,5 +1,24 @@ # Changelog +## [0.6.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.5.0...rust-mcp-transport-v0.6.0) (2025-09-19) + + +### ⚠ BREAKING CHANGES + +* add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) +* update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) + +### πŸš€ Features + +* Add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) ([abb0c36](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/abb0c36126b0a397bc20a1de36c5a5a80924a01e)) +* Event store support for resumability ([#101](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/101)) ([08742bb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/08742bb9636f81ee79eda4edc192b3b8ed4c7287)) +* Update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) ([a2d6d23](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/a2d6d23ab59fbc34d04526e2606f747f93a8468c)) + + +### πŸ› Bug Fixes + +* Correct pending_requests instance ([#94](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/94)) ([9d8c1fb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/9d8c1fbdf3ddb7c67ce1fb7dcb8e50b8ba2e1202)) + ## [0.5.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.5.0...rust-mcp-transport-v0.5.1) (2025-08-31) diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index 2f03580..8331eaf 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-transport" -version = "0.5.1" +version = "0.6.0" authors = ["Ali Hashemi"] categories = ["data-structures"] description = "Transport implementations for the MCP (Model Context Protocol) within the rust-mcp-sdk ecosystem, enabling asynchronous data exchange and efficient message handling between MCP clients and servers." diff --git a/examples/hello-world-mcp-server-stdio-core/Cargo.toml b/examples/hello-world-mcp-server-stdio-core/Cargo.toml index 14eb904..f37d4c4 100644 --- a/examples/hello-world-mcp-server-stdio-core/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-stdio-core" -version = "0.1.19" +version = "0.1.20" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-mcp-server-stdio/Cargo.toml b/examples/hello-world-mcp-server-stdio/Cargo.toml index 9d15be3..1947dce 100644 --- a/examples/hello-world-mcp-server-stdio/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-mcp-server-stdio" -version = "0.1.28" +version = "0.1.29" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http-core/Cargo.toml b/examples/hello-world-server-streamable-http-core/Cargo.toml index a762058..85e470a 100644 --- a/examples/hello-world-server-streamable-http-core/Cargo.toml +++ b/examples/hello-world-server-streamable-http-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http-core" -version = "0.1.19" +version = "0.1.20" edition = "2021" publish = false license = "MIT" diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index 17a87c8..61d080f 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.31" +version = "0.1.32" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse-core/Cargo.toml b/examples/simple-mcp-client-sse-core/Cargo.toml index 25dcd7d..05654fc 100644 --- a/examples/simple-mcp-client-sse-core/Cargo.toml +++ b/examples/simple-mcp-client-sse-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse-core" -version = "0.1.19" +version = "0.1.20" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index bf7174d..0720afe 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.22" +version = "0.1.23" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-stdio-core/Cargo.toml b/examples/simple-mcp-client-stdio-core/Cargo.toml index 6d95cf6..f7dc568 100644 --- a/examples/simple-mcp-client-stdio-core/Cargo.toml +++ b/examples/simple-mcp-client-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-stdio-core" -version = "0.1.28" +version = "0.1.29" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-stdio/Cargo.toml b/examples/simple-mcp-client-stdio/Cargo.toml index 3597105..7bbd890 100644 --- a/examples/simple-mcp-client-stdio/Cargo.toml +++ b/examples/simple-mcp-client-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-stdio" -version = "0.1.28" +version = "0.1.29" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-streamable-http-core/Cargo.toml b/examples/simple-mcp-client-streamable-http-core/Cargo.toml index 68356e1..c8b3464 100644 --- a/examples/simple-mcp-client-streamable-http-core/Cargo.toml +++ b/examples/simple-mcp-client-streamable-http-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-streamable-http-core" -version = "0.1.0" +version = "0.1.1" edition = "2021" publish = false license = "MIT" diff --git a/examples/simple-mcp-client-streamable-http/Cargo.toml b/examples/simple-mcp-client-streamable-http/Cargo.toml index 0638aab..bf2827a 100644 --- a/examples/simple-mcp-client-streamable-http/Cargo.toml +++ b/examples/simple-mcp-client-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-streamable-http" -version = "0.1.0" +version = "0.1.1" edition = "2021" publish = false license = "MIT"