From 927b5c67aa14300dfaad8b8995537f5292612d95 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Fri, 10 Oct 2025 15:56:29 -0300 Subject: [PATCH 01/14] chore: refactor hyperserver --- Cargo.lock | 4 ++++ crates/rust-mcp-sdk/Cargo.toml | 4 ++++ crates/rust-mcp-sdk/src/hyper_servers.rs | 3 --- .../src/hyper_servers/hyper_runtime.rs | 2 +- .../middlewares/protect_dns_rebinding.rs | 3 +-- .../hyper_servers/middlewares/session_id_gen.rs | 6 ++---- crates/rust-mcp-sdk/src/hyper_servers/routes.rs | 5 +++-- .../src/hyper_servers/routes/messages_routes.rs | 6 ++---- .../src/hyper_servers/routes/sse_routes.rs | 2 +- .../routes/streamable_http_routes.rs | 15 ++++++--------- crates/rust-mcp-sdk/src/hyper_servers/server.rs | 3 +-- crates/rust-mcp-sdk/src/lib.rs | 1 + crates/rust-mcp-sdk/src/mcp_http.rs | 9 +++++++++ .../src/{hyper_servers => mcp_http}/app_state.rs | 2 +- .../rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs | 1 + .../hyper_utils.rs => mcp_http/mcp_http_utils.rs} | 6 ++---- .../{hyper_servers => mcp_http}/session_store.rs | 0 .../session_store/in_memory.rs | 0 crates/rust-mcp-sdk/src/utils.rs | 3 +-- 19 files changed, 40 insertions(+), 35 deletions(-) create mode 100644 crates/rust-mcp-sdk/src/mcp_http.rs rename crates/rust-mcp-sdk/src/{hyper_servers => mcp_http}/app_state.rs (97%) create mode 100644 crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs rename crates/rust-mcp-sdk/src/{hyper_servers/routes/hyper_utils.rs => mcp_http/mcp_http_utils.rs} (99%) rename crates/rust-mcp-sdk/src/{hyper_servers => mcp_http}/session_store.rs (100%) rename crates/rust-mcp-sdk/src/{hyper_servers => mcp_http}/session_store/in_memory.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 0acb30d..51d9cb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1658,7 +1658,11 @@ dependencies = [ "axum", "axum-server", "base64 0.22.1", + "bytes", "futures", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", "hyper 1.7.0", "reqwest", "rust-mcp-macros", diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 8bba7c7..85371fe 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -32,6 +32,10 @@ base64.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } +http = "1.3.1" +http-body-util = "0.1.3" +http-body = "1.0.1" +bytes.workspace = true [dev-dependencies] wiremock = "0.5" diff --git a/crates/rust-mcp-sdk/src/hyper_servers.rs b/crates/rust-mcp-sdk/src/hyper_servers.rs index f18c428..318720e 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers.rs @@ -1,4 +1,3 @@ -mod app_state; pub mod error; pub mod hyper_runtime; pub mod hyper_server; @@ -6,7 +5,5 @@ pub mod hyper_server_core; mod middlewares; mod routes; mod server; -mod session_store; pub use server::*; -pub use session_store::*; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs index 85cf791..59c5b4d 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -1,6 +1,7 @@ use std::{sync::Arc, time::Duration}; use crate::{ + mcp_http::AppState, mcp_server::HyperServer, schema::{ schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, @@ -18,7 +19,6 @@ use tokio::{sync::Mutex, task::JoinHandle}; use crate::{ error::SdkResult, - hyper_servers::app_state::AppState, mcp_server::{ error::{TransportServerError, TransportServerResult}, ServerRuntime, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs index 5674e87..1fdcfef 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs @@ -1,5 +1,4 @@ -use crate::hyper_servers::app_state::AppState; -use crate::schema::schema_utils::SdkError; +use crate::{mcp_http::AppState, schema::schema_utils::SdkError}; use axum::{ extract::{Request, State}, middleware::Next, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs index b68b325..7611b3c 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs @@ -1,5 +1,4 @@ -use std::sync::Arc; - +use crate::mcp_http::AppState; use axum::{ extract::{Request, State}, middleware::Next, @@ -7,8 +6,7 @@ use axum::{ }; use hyper::StatusCode; use rust_mcp_transport::SessionId; - -use crate::hyper_servers::app_state::AppState; +use std::sync::Arc; // Middleware to generate and attach a session ID pub async fn generate_session_id( diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index b1b15fc..90844c6 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -1,10 +1,11 @@ pub mod fallback_routes; -mod hyper_utils; pub mod messages_routes; pub mod sse_routes; pub mod streamable_http_routes; -use super::{app_state::AppState, HyperServerOptions}; +use crate::mcp_http::AppState; + +use super::HyperServerOptions; use axum::Router; use std::sync::Arc; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 44b671f..bce49e1 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -1,8 +1,6 @@ use crate::{ - hyper_servers::{ - app_state::AppState, - error::{TransportServerError, TransportServerResult}, - }, + hyper_servers::error::{TransportServerError, TransportServerResult}, + mcp_http::AppState, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, utils::remove_query_and_hash, }; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 27a16b2..220b6f4 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,8 +1,8 @@ +use crate::mcp_http::AppState; use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ hyper_servers::{ - app_state::AppState, error::TransportServerResult, middlewares::{ protect_dns_rebinding::protect_dns_rebinding, session_id_gen::generate_session_id, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 67f8679..9a19470 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,16 +1,13 @@ -use super::hyper_utils::start_new_session; +use crate::mcp_http::{ + acceptable_content_type, accepts_event_stream, create_standalone_stream, delete_session, + process_incoming_message, process_incoming_message_return, start_new_session, + valid_streaming_http_accept_header, validate_mcp_protocol_version_header, AppState, +}; use crate::schema::schema_utils::SdkError; use crate::{ error::McpSdkError, hyper_servers::{ - app_state::AppState, - error::TransportServerResult, - middlewares::protect_dns_rebinding::protect_dns_rebinding, - routes::hyper_utils::{ - acceptable_content_type, accepts_event_stream, create_standalone_stream, - delete_session, process_incoming_message, process_incoming_message_return, - valid_streaming_http_accept_header, validate_mcp_protocol_version_header, - }, + error::TransportServerResult, middlewares::protect_dns_rebinding::protect_dns_rebinding, }, utils::valid_initialize_method, }; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 71bccee..4fbd8ad 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,6 +1,7 @@ use crate::{ error::SdkResult, id_generator::{FastIdGenerator, UuidGenerator}, + mcp_http::{AppState, InMemorySessionStore}, mcp_server::hyper_runtime::HyperRuntime, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; @@ -16,10 +17,8 @@ use std::{ use tokio::signal; use super::{ - app_state::AppState, error::{TransportServerError, TransportServerResult}, routes::app_routes, - InMemorySessionStore, }; use crate::schema::InitializeResult; use axum::Router; diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index a33f889..70480c0 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -2,6 +2,7 @@ pub mod error; #[cfg(feature = "hyper-server")] mod hyper_servers; mod mcp_handlers; +pub(crate) mod mcp_http; mod mcp_macros; mod mcp_runtimes; mod mcp_traits; diff --git a/crates/rust-mcp-sdk/src/mcp_http.rs b/crates/rust-mcp-sdk/src/mcp_http.rs new file mode 100644 index 0000000..40038fd --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http.rs @@ -0,0 +1,9 @@ +mod app_state; +mod mcp_http_handler; +mod mcp_http_utils; +mod session_store; + +pub(crate) use app_state::*; +pub use mcp_http_handler::*; +pub(crate) use mcp_http_utils::*; +pub use session_store::*; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs similarity index 97% rename from crates/rust-mcp-sdk/src/hyper_servers/app_state.rs rename to crates/rust-mcp-sdk/src/mcp_http/app_state.rs index f96b261..9553bff 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs @@ -8,7 +8,7 @@ use rust_mcp_transport::event_store::EventStore; use rust_mcp_transport::{SessionId, TransportOptions}; -/// Application state struct for the Hyper server +/// Application state struct for the Hyper ser /// /// Holds shared, thread-safe references to session storage, ID generator, /// server details, handler, ping interval, and transport options. diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -0,0 +1 @@ + diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs similarity index 99% rename from crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs rename to crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index 7101a73..d1a3594 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -1,9 +1,7 @@ use crate::{ error::SdkResult, - hyper_servers::{ - app_state::AppState, - error::{TransportServerError, TransportServerResult}, - }, + hyper_servers::error::{TransportServerError, TransportServerResult}, + mcp_http::AppState, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/mcp_http/session_store.rs similarity index 100% rename from crates/rust-mcp-sdk/src/hyper_servers/session_store.rs rename to crates/rust-mcp-sdk/src/mcp_http/session_store.rs diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs b/crates/rust-mcp-sdk/src/mcp_http/session_store/in_memory.rs similarity index 100% rename from crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs rename to crates/rust-mcp-sdk/src/mcp_http/session_store/in_memory.rs diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index 16fe7c7..2d80f1e 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -1,6 +1,5 @@ -use crate::schema::schema_utils::{ClientMessages, SdkError}; - use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; +use crate::schema::schema_utils::{ClientMessages, SdkError}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; From 91011e81d797ff67f65f9ec0b5af7c5626711e78 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Fri, 10 Oct 2025 16:07:57 -0300 Subject: [PATCH 02/14] chore: refactor utils --- .../src/hyper_servers/routes/streamable_http_routes.rs | 10 +++++++--- crates/rust-mcp-sdk/src/mcp_http.rs | 7 +++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 9a19470..696204b 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,9 +1,11 @@ -use crate::mcp_http::{ +use crate::mcp_http::utils::{ acceptable_content_type, accepts_event_stream, create_standalone_stream, delete_session, process_incoming_message, process_incoming_message_return, start_new_session, - valid_streaming_http_accept_header, validate_mcp_protocol_version_header, AppState, + valid_streaming_http_accept_header, validate_mcp_protocol_version_header, }; +use crate::mcp_http::AppState; use crate::schema::schema_utils::SdkError; +use crate::utils::validate_mcp_protocol_version; use crate::{ error::McpSdkError, hyper_servers::{ @@ -20,7 +22,9 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; +use rust_mcp_transport::{ + SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, +}; use std::{collections::HashMap, sync::Arc}; pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { diff --git a/crates/rust-mcp-sdk/src/mcp_http.rs b/crates/rust-mcp-sdk/src/mcp_http.rs index 40038fd..59cfedc 100644 --- a/crates/rust-mcp-sdk/src/mcp_http.rs +++ b/crates/rust-mcp-sdk/src/mcp_http.rs @@ -1,9 +1,12 @@ mod app_state; mod mcp_http_handler; -mod mcp_http_utils; +pub(crate) mod mcp_http_utils; mod session_store; pub(crate) use app_state::*; pub use mcp_http_handler::*; -pub(crate) use mcp_http_utils::*; pub use session_store::*; + +pub(crate) mod utils { + pub use super::mcp_http_utils::*; +} From 49a15cdee74f06c25f448b8edd096280e9aaa3e9 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Fri, 10 Oct 2025 16:30:11 -0300 Subject: [PATCH 03/14] chore: use proper consts for header names --- crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index d1a3594..6ff48ec 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -19,6 +19,7 @@ use axum::{ Json, }; use futures::stream; +use http::header::ACCEPT; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, @@ -477,7 +478,7 @@ pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<() pub fn accepts_event_stream(headers: &HeaderMap) -> bool { let accept_header = headers - .get("accept") + .get(ACCEPT) .and_then(|val| val.to_str().ok()) .unwrap_or(""); @@ -488,7 +489,7 @@ pub fn accepts_event_stream(headers: &HeaderMap) -> bool { pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { let accept_header = headers - .get("accept") + .get(ACCEPT) .and_then(|val| val.to_str().ok()) .unwrap_or(""); From dcfcc308ee3bd638c62375e33a318e99b45397d6 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Fri, 10 Oct 2025 19:03:24 -0300 Subject: [PATCH 04/14] chore: clean up app state --- .../rust-mcp-sdk/src/hyper_servers/routes.rs | 5 +++- .../hyper_servers/routes/messages_routes.rs | 4 +-- .../src/hyper_servers/routes/sse_routes.rs | 25 ++++++++++++++++--- .../rust-mcp-sdk/src/hyper_servers/server.rs | 2 -- crates/rust-mcp-sdk/src/mcp_http/app_state.rs | 2 -- .../src/mcp_http/mcp_http_utils.rs | 18 +++++++++++-- 6 files changed, 44 insertions(+), 12 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index 90844c6..ae633d8 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -33,8 +33,11 @@ pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> .merge(sse_routes::routes( state.clone(), server_options.sse_endpoint(), + server_options.sse_messages_endpoint(), + )) + .merge(messages_routes::routes( + server_options.sse_messages_endpoint(), )) - .merge(messages_routes::routes(state.clone())) } r }) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index bce49e1..89a3b01 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -12,9 +12,9 @@ use axum::{ }; use std::{collections::HashMap, sync::Arc}; -pub fn routes(state: Arc) -> Router> { +pub fn routes(sse_message_endpoint: &str) -> Router> { Router::new().route( - remove_query_and_hash(&state.sse_message_endpoint).as_str(), + remove_query_and_hash(&sse_message_endpoint).as_str(), post(handle_messages), ) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 220b6f4..9e7e900 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -30,6 +30,9 @@ use tokio_stream::StreamExt; const DUPLEX_BUFFER_SIZE: usize = 8192; +#[derive(Clone)] +pub struct SseMessageEndpoint(pub String); + /// Creates an initial SSE event that returns the messages endpoint /// /// Constructs an SSE event containing the messages endpoint URL with the session ID. @@ -53,9 +56,17 @@ fn initial_event(endpoint: &str) -> Result { /// /// # Returns /// * `Router>` - An Axum router configured with the SSE route -pub fn routes(state: Arc, sse_endpoint: &str) -> Router> { +pub fn routes( + state: Arc, + sse_endpoint: &str, + sse_message_endpoint: &str, +) -> Router> { + let sse_message_endpoint = SseMessageEndpoint(sse_message_endpoint.to_string()); Router::new() - .route(sse_endpoint, get(handle_sse)) + .route( + sse_endpoint, + get(handle_sse).layer(Extension(sse_message_endpoint)), + ) .route_layer(middleware::from_fn_with_state( state.clone(), generate_session_id, @@ -78,10 +89,18 @@ pub fn routes(state: Arc, sse_endpoint: &str) -> Router> /// * `TransportServerResult` - The SSE response stream or an error pub async fn handle_sse( Extension(session_id): Extension, + Extension(sse_message_endpoint): Extension, State(state): State>, ) -> TransportServerResult { + let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; + tracing::warn!( + ">>> session_id {:?}, sse_message_endpoint>>> {:?}", + session_id, + sse_message_endpoint + ); + let messages_endpoint = - SseTransport::::message_endpoint(&state.sse_message_endpoint, &session_id); + SseTransport::::message_endpoint(&sse_message_endpoint, &session_id); // readable stream of string to be used in transport // writing string to read_tx will be received as messages inside the transport and messages will be processed diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 4fbd8ad..7bd1658 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -268,8 +268,6 @@ impl HyperServer { server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, - sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(), - http_streamable_endpoint: server_options.streamable_http_endpoint().to_owned(), transport_options: Arc::clone(&server_options.transport_options), enable_json_response: server_options.enable_json_response.unwrap_or(false), allowed_hosts: server_options.allowed_hosts.take(), diff --git a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs index 9553bff..3bf81f7 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs @@ -20,8 +20,6 @@ pub struct AppState { pub server_details: Arc, pub handler: Arc, pub ping_interval: Duration, - pub sse_message_endpoint: String, - pub http_streamable_endpoint: String, pub transport_options: Arc, pub enable_json_response: bool, /// List of allowed host header values for DNS rebinding protection. diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index 6ff48ec..b9ef5d6 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -1,7 +1,7 @@ use crate::{ error::SdkResult, hyper_servers::error::{TransportServerError, TransportServerResult}, - mcp_http::AppState, + mcp_http::{AppState, GenericBody}, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, @@ -18,8 +18,10 @@ use axum::{ }, Json, }; +use bytes::Bytes; use futures::stream; -use http::header::ACCEPT; +use http::header::{ACCEPT, CONTENT_TYPE}; +use http_body_util::{BodyExt, Full}; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, @@ -499,3 +501,15 @@ pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { let has_json = types.iter().any(|v| v.starts_with("application/json")); has_event_stream && has_json } + +pub fn error_response( + status_code: StatusCode, + error: SdkError, +) -> Result, http::Error> { + let error_string = serde_json::to_string(&error).unwrap_or_default(); + let body = Full::new(Bytes::from(error_string)).boxed(); + http::Response::builder() + .status(status_code) + .header(CONTENT_TYPE, "application/json") + .body(body) +} From 1616c4c9763a798c4559cb55e988f52c90066f13 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 11 Oct 2025 08:26:18 -0300 Subject: [PATCH 05/14] chore: adjust cagro features --- crates/rust-mcp-sdk/src/lib.rs | 1 + crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 70480c0..2f88673 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -2,6 +2,7 @@ pub mod error; #[cfg(feature = "hyper-server")] mod hyper_servers; mod mcp_handlers; +#[cfg(feature = "hyper-server")] pub(crate) mod mcp_http; mod mcp_macros; mod mcp_runtimes; diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index b9ef5d6..97fcb06 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -1,15 +1,13 @@ +use crate::schema::schema_utils::{ClientMessage, SdkError}; use crate::{ error::SdkResult, hyper_servers::error::{TransportServerError, TransportServerResult}, - mcp_http::{AppState, GenericBody}, + mcp_http::AppState, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, utils::validate_mcp_protocol_version, }; - -use crate::schema::schema_utils::{ClientMessage, SdkError}; - use axum::{http::HeaderValue, response::IntoResponse}; use axum::{ response::{ @@ -21,7 +19,7 @@ use axum::{ use bytes::Bytes; use futures::stream; use http::header::{ACCEPT, CONTENT_TYPE}; -use http_body_util::{BodyExt, Full}; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, @@ -32,6 +30,8 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader}; const DUPLEX_BUFFER_SIZE: usize = 8192; +pub type GenericBody = BoxBody; + async fn create_sse_stream( runtime: Arc, session_id: SessionId, From 000910a7d65ece421e5fbda88482de0af64251d4 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 11 Oct 2025 08:31:47 -0300 Subject: [PATCH 06/14] chore: rename AppState to McpAppState --- .../src/hyper_servers/hyper_runtime.rs | 4 ++-- .../middlewares/protect_dns_rebinding.rs | 4 ++-- .../hyper_servers/middlewares/session_id_gen.rs | 4 ++-- crates/rust-mcp-sdk/src/hyper_servers/routes.rs | 4 ++-- .../src/hyper_servers/routes/messages_routes.rs | 6 +++--- .../src/hyper_servers/routes/sse_routes.rs | 10 +++++----- .../routes/streamable_http_routes.rs | 10 +++++----- crates/rust-mcp-sdk/src/hyper_servers/server.rs | 12 ++++++------ crates/rust-mcp-sdk/src/mcp_http/app_state.rs | 4 ++-- .../rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs | 16 ++++++++-------- 10 files changed, 37 insertions(+), 37 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs index 59c5b4d..92eed79 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -1,7 +1,7 @@ use std::{sync::Arc, time::Duration}; use crate::{ - mcp_http::AppState, + mcp_http::McpAppState, mcp_server::HyperServer, schema::{ schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, @@ -26,7 +26,7 @@ use crate::{ }; pub struct HyperRuntime { - pub(crate) state: Arc, + pub(crate) state: Arc, pub(crate) server_task: JoinHandle>, pub(crate) server_handle: Handle, } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs index 1fdcfef..3ba8c85 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs @@ -1,4 +1,4 @@ -use crate::{mcp_http::AppState, schema::schema_utils::SdkError}; +use crate::{mcp_http::McpAppState, schema::schema_utils::SdkError}; use axum::{ extract::{Request, State}, middleware::Next, @@ -14,7 +14,7 @@ use std::sync::Arc; // Middleware to protect against DNS rebinding attacks by validating Host and Origin headers. pub async fn protect_dns_rebinding( headers: HeaderMap, - State(state): State>, + State(state): State>, request: Request, next: Next, ) -> impl IntoResponse { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs index 7611b3c..878e3ee 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs @@ -1,4 +1,4 @@ -use crate::mcp_http::AppState; +use crate::mcp_http::McpAppState; use axum::{ extract::{Request, State}, middleware::Next, @@ -10,7 +10,7 @@ use std::sync::Arc; // Middleware to generate and attach a session ID pub async fn generate_session_id( - State(state): State>, + State(state): State>, mut request: Request, next: Next, ) -> Result { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index ae633d8..cd79580 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -3,7 +3,7 @@ pub mod messages_routes; pub mod sse_routes; pub mod streamable_http_routes; -use crate::mcp_http::AppState; +use crate::mcp_http::McpAppState; use super::HyperServerOptions; use axum::Router; @@ -20,7 +20,7 @@ use std::sync::Arc; /// /// # Returns /// * `Router` - An Axum router configured with all application routes and state -pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> Router { +pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> Router { let router: Router = Router::new() .merge(streamable_http_routes::routes( state.clone(), diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 89a3b01..39aa983 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -1,6 +1,6 @@ use crate::{ hyper_servers::error::{TransportServerError, TransportServerResult}, - mcp_http::AppState, + mcp_http::McpAppState, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, utils::remove_query_and_hash, }; @@ -12,7 +12,7 @@ use axum::{ }; use std::{collections::HashMap, sync::Arc}; -pub fn routes(sse_message_endpoint: &str) -> Router> { +pub fn routes(sse_message_endpoint: &str) -> Router> { Router::new().route( remove_query_and_hash(&sse_message_endpoint).as_str(), post(handle_messages), @@ -20,7 +20,7 @@ pub fn routes(sse_message_endpoint: &str) -> Router> { } pub async fn handle_messages( - State(state): State>, + State(state): State>, Query(params): Query>, message: String, ) -> TransportServerResult { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 9e7e900..21bfc09 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,4 +1,4 @@ -use crate::mcp_http::AppState; +use crate::mcp_http::McpAppState; use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ @@ -55,12 +55,12 @@ fn initial_event(endpoint: &str) -> Result { /// * `sse_endpoint` - The path for the SSE endpoint /// /// # Returns -/// * `Router>` - An Axum router configured with the SSE route +/// * `Router>` - An Axum router configured with the SSE route pub fn routes( - state: Arc, + state: Arc, sse_endpoint: &str, sse_message_endpoint: &str, -) -> Router> { +) -> Router> { let sse_message_endpoint = SseMessageEndpoint(sse_message_endpoint.to_string()); Router::new() .route( @@ -90,7 +90,7 @@ pub fn routes( pub async fn handle_sse( Extension(session_id): Extension, Extension(sse_message_endpoint): Extension, - State(state): State>, + State(state): State>, ) -> TransportServerResult { let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; tracing::warn!( diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 696204b..b7ec2f3 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -3,7 +3,7 @@ use crate::mcp_http::utils::{ process_incoming_message, process_incoming_message_return, start_new_session, valid_streaming_http_accept_header, validate_mcp_protocol_version_header, }; -use crate::mcp_http::AppState; +use crate::mcp_http::McpAppState; use crate::schema::schema_utils::SdkError; use crate::utils::validate_mcp_protocol_version; use crate::{ @@ -27,7 +27,7 @@ use rust_mcp_transport::{ }; use std::{collections::HashMap, sync::Arc}; -pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { +pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { Router::new() .route(streamable_http_endpoint, get(handle_streamable_http_get)) .route(streamable_http_endpoint, post(handle_streamable_http_post)) @@ -43,7 +43,7 @@ pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router>, + State(state): State>, ) -> TransportServerResult { if !accepts_event_stream(&headers) { let error = SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#); @@ -80,7 +80,7 @@ pub async fn handle_streamable_http_get( pub async fn handle_streamable_http_post( headers: HeaderMap, - State(state): State>, + State(state): State>, Query(_params): Query>, payload: String, ) -> TransportServerResult { @@ -137,7 +137,7 @@ pub async fn handle_streamable_http_post( pub async fn handle_streamable_http_delete( headers: HeaderMap, - State(state): State>, + State(state): State>, ) -> TransportServerResult { if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { let error = diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 7bd1658..bfce062 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,7 +1,7 @@ use crate::{ error::SdkResult, id_generator::{FastIdGenerator, UuidGenerator}, - mcp_http::{AppState, InMemorySessionStore}, + mcp_http::{InMemorySessionStore, McpAppState}, mcp_server::hyper_runtime::HyperRuntime, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; @@ -236,7 +236,7 @@ impl Default for HyperServerOptions { /// Hyper server struct for managing the Axum-based web server pub struct HyperServer { app: Router, - state: Arc, + state: Arc, pub(crate) options: HyperServerOptions, handle: Handle, } @@ -258,7 +258,7 @@ impl HyperServer { handler: Arc, mut server_options: HyperServerOptions, ) -> Self { - let state: Arc = Arc::new(AppState { + let state: Arc = Arc::new(McpAppState { session_store: Arc::new(InMemorySessionStore::new()), id_generator: server_options .session_id_generator @@ -287,8 +287,8 @@ impl HyperServer { /// Returns a shared reference to the application state /// /// # Returns - /// * `Arc` - Shared application state - pub fn state(&self) -> Arc { + /// * `Arc` - Shared application state + pub fn state(&self) -> Arc { Arc::clone(&self.state) } @@ -448,7 +448,7 @@ impl HyperServer { } // Shutdown signal handler -async fn shutdown_signal(handle: Handle, state: Arc) { +async fn shutdown_signal(handle: Handle, state: Arc) { // Wait for a Ctrl+C or SIGTERM signal let ctrl_c = async { signal::ctrl_c() diff --git a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs index 3bf81f7..95ae297 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs @@ -13,7 +13,7 @@ use rust_mcp_transport::{SessionId, TransportOptions}; /// Holds shared, thread-safe references to session storage, ID generator, /// server details, handler, ping interval, and transport options. #[derive(Clone)] -pub struct AppState { +pub struct McpAppState { pub session_store: Arc, pub id_generator: Arc>, pub stream_id_gen: Arc, @@ -36,7 +36,7 @@ pub struct AppState { pub event_store: Option>, } -impl AppState { +impl McpAppState { pub fn needs_dns_protection(&self) -> bool { self.dns_rebinding_protection && (self.allowed_hosts.is_some() || self.allowed_origins.is_some()) diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index 97fcb06..a8b89b1 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -2,7 +2,7 @@ use crate::schema::schema_utils::{ClientMessage, SdkError}; use crate::{ error::SdkResult, hyper_servers::error::{TransportServerError, TransportServerResult}, - mcp_http::AppState, + mcp_http::McpAppState, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, @@ -35,7 +35,7 @@ pub type GenericBody = BoxBody; async fn create_sse_stream( runtime: Arc, session_id: SessionId, - state: Arc, + state: Arc, payload: Option<&str>, standalone: bool, last_event_id: Option, @@ -204,7 +204,7 @@ fn is_result(json_str: &str) -> Result { pub async fn create_standalone_stream( session_id: SessionId, last_event_id: Option, - state: Arc, + state: Arc, ) -> TransportServerResult> { let runtime = state.session_store.get(&session_id).await.ok_or( TransportServerError::SessionIdInvalid(session_id.to_string()), @@ -238,7 +238,7 @@ pub async fn create_standalone_stream( } pub async fn start_new_session( - state: Arc, + state: Arc, payload: &str, ) -> TransportServerResult> { let session_id: SessionId = state.id_generator.generate(); @@ -275,7 +275,7 @@ pub async fn start_new_session( async fn single_shot_stream( runtime: Arc, session_id: SessionId, - state: Arc, + state: Arc, payload: Option<&str>, standalone: bool, ) -> TransportServerResult> { @@ -362,7 +362,7 @@ async fn single_shot_stream( pub async fn process_incoming_message_return( session_id: SessionId, - state: Arc, + state: Arc, payload: &str, ) -> TransportServerResult { match state.session_store.get(&session_id).await { @@ -388,7 +388,7 @@ pub async fn process_incoming_message_return( pub async fn process_incoming_message( session_id: SessionId, - state: Arc, + state: Arc, payload: &str, ) -> TransportServerResult { match state.session_store.get(&session_id).await { @@ -437,7 +437,7 @@ pub fn is_empty_sse_message(sse_payload: &str) -> bool { pub async fn delete_session( session_id: SessionId, - state: Arc, + state: Arc, ) -> TransportServerResult { match state.session_store.get(&session_id).await { Some(runtime) => { From edf9af86b8ee0250784d14671b169499ad71d558 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 11 Oct 2025 20:08:18 -0300 Subject: [PATCH 07/14] feat: initial implementation of mcp_http_handler --- .../rust-mcp-sdk/src/hyper_servers/error.rs | 4 +- .../src/mcp_http/mcp_http_handler.rs | 166 ++++++ .../src/mcp_http/mcp_http_utils.rs | 486 +++++++++++++++++- crates/rust-mcp-transport/src/lib.rs | 2 + crates/rust-mcp-transport/src/utils.rs | 2 + .../src/utils/sse_parser.rs | 95 +++- 6 files changed, 749 insertions(+), 6 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/error.rs b/crates/rust-mcp-sdk/src/hyper_servers/error.rs index 74cbcd1..f0590dd 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/error.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/error.rs @@ -1,4 +1,4 @@ -use std::net::AddrParseError; +use std::{net::AddrParseError, path::Display}; use axum::{http::StatusCode, response::IntoResponse}; use thiserror::Error; @@ -15,6 +15,8 @@ pub enum TransportServerError { StreamIoError(String), #[error("{0}")] AddrParseError(#[from] AddrParseError), + #[error("{0}")] + HttpError(String), #[error("Server start error: {0}")] ServerStartError(String), #[error("Invalid options: {0}")] diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 8b13789..36dc3dd 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1 +1,167 @@ +use std::sync::Arc; +use crate::{ + error::McpSdkError, + mcp_http::{ + utils::{ + acceptable_content_type, create_standalone_stream, create_standalone_stream_x, + delete_session, delete_session_x, process_incoming_message_return_x, + process_incoming_message_x, start_new_session_x, valid_streaming_http_accept_header, + GenericBody, + }, + McpAppState, + }, + mcp_server::error::TransportServerResult, + schema::schema_utils::SdkError, + utils::valid_initialize_method, +}; +use axum::response::ErrorResponse; +use bytes::Bytes; +use http::{self, header::CONTENT_TYPE, StatusCode}; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; + +use crate::mcp_http::utils::{ + accepts_event_stream, error_response, validate_mcp_protocol_version_header, +}; + +pub struct McpHttpHandler {} + +impl McpHttpHandler { + pub async fn handle_streamable_http( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let method = request.method(); + match method { + &http::Method::GET => return Self::handle_http_get(request, state).await, + &http::Method::POST => return Self::handle_http_post(request, state).await, + &http::Method::DELETE => return Self::handle_http_delete(request, state).await, + other => { + let error = SdkError::bad_request().with_message(&format!( + "'{other}' is not a valid HTTP method for StreamableHTTP transport." + )); + return error_response(StatusCode::METHOD_NOT_ALLOWED, error); + } + } + } + + async fn handle_http_delete( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let headers = request.headers(); + + if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + let error = SdkError::bad_request() + .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return error_response(StatusCode::BAD_REQUEST, error); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + match session_id { + Some(id) => delete_session_x(id, state).await, + None => { + let error = SdkError::bad_request().with_message("Bad Request: Session not found"); + error_response(StatusCode::BAD_REQUEST, error) + } + } + } + + async fn handle_http_post( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let headers = request.headers(); + + if !valid_streaming_http_accept_header(headers) { + let error = SdkError::bad_request() + .with_message(r#"Client must accept both application/json and text/event-stream"#); + return error_response(StatusCode::NOT_ACCEPTABLE, error); + } + + if !acceptable_content_type(headers) { + let error = SdkError::bad_request() + .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#); + return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error); + } + + if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + let error = SdkError::bad_request() + .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return error_response(StatusCode::BAD_REQUEST, error); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let payload = *request.body(); + + match session_id { + // has session-id => write to the existing stream + Some(id) => { + if state.enable_json_response { + process_incoming_message_return_x(id, state, payload).await + } else { + process_incoming_message_x(id, state, &payload).await + } + } + None => match valid_initialize_method(&payload) { + Ok(_) => { + return start_new_session_x(state, &payload).await; + } + Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error), + Err(error) => { + let error = SdkError::bad_request().with_message(&error.to_string()); + error_response(StatusCode::BAD_REQUEST, error) + } + }, + } + } + + async fn handle_http_get( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let headers = request.headers(); + + if !accepts_event_stream(headers) { + let error = + SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#); + return error_response(StatusCode::NOT_ACCEPTABLE, error); + } + + if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + let error = SdkError::bad_request() + .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return error_response(StatusCode::BAD_REQUEST, error); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let last_event_id: Option = headers + .get(MCP_LAST_EVENT_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + match session_id { + Some(session_id) => { + let res = create_standalone_stream_x(session_id, last_event_id, state).await; + res + } + None => { + let error = SdkError::bad_request().with_message("Bad request: session not found"); + error_response(StatusCode::BAD_REQUEST, error) + } + } + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index a8b89b1..f523215 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -19,18 +19,22 @@ use axum::{ use bytes::Bytes; use futures::stream; use http::header::{ACCEPT, CONTENT_TYPE}; +use http_body::Frame; +use http_body_util::StreamBody; use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::{header, HeaderMap, StatusCode}; +use rust_mcp_transport::error::TransportError; use rust_mcp_transport::{ - EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, + EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; +use tokio_stream::StreamExt; const DUPLEX_BUFFER_SIZE: usize = 8192; -pub type GenericBody = BoxBody; +pub type GenericBody = BoxBody; async fn create_sse_stream( runtime: Arc, @@ -505,11 +509,485 @@ pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { pub fn error_response( status_code: StatusCode, error: SdkError, -) -> Result, http::Error> { +) -> TransportServerResult> { let error_string = serde_json::to_string(&error).unwrap_or_default(); - let body = Full::new(Bytes::from(error_string)).boxed(); + let body = Full::new(Bytes::from(error_string)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() .status(status_code) .header(CONTENT_TYPE, "application/json") .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) +} + +// pub fn error_response( +// status_code: StatusCode, +// error: SdkError, +// headers: Option +// ) -> TransportServerResult> { +// let error_string = serde_json::to_string(&error).unwrap_or_default(); +// let body = Full::new(Bytes::from(error_string)) +// .map_err(|err| TransportServerError::HttpError(err.to_string())) +// .boxed(); + +// let mut response = http::Response::builder() +// .status(status_code) +// .header(CONTENT_TYPE, "application/json") +// .body(body) +// .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + +// if let Some(header_map) = headers { +// let response_headers = response.headers_mut(); +// for (key, value) in header_map.into_iter() { +// // Only insert valid headers (Some keys), as `HeaderMap` can contain pseudo-headers +// if let Some(key) = key { +// response_headers.insert(key, value); +// } +// } +// } + +// Ok(response) +// } + +pub async fn create_standalone_stream_x( + session_id: SessionId, + last_event_id: Option, + state: Arc, +) -> TransportServerResult> { + let runtime = state.session_store.get(&session_id).await.ok_or( + TransportServerError::SessionIdInvalid(session_id.to_string()), + )?; + let runtime = runtime.lock().await.to_owned(); + + if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { + let error = + SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); + return error_response(StatusCode::CONFLICT, error) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + } + + if let Some(last_event_id) = last_event_id.as_ref() { + tracing::trace!( + "SSE stream re-connected with last-event-id: {}", + last_event_id + ); + } + + let mut response = create_sse_stream_x( + runtime.clone(), + session_id.clone(), + state.clone(), + None, + true, + last_event_id, + ) + .await?; + *response.status_mut() = StatusCode::OK; + Ok(response) +} + +async fn create_sse_stream_x( + runtime: Arc, + session_id: SessionId, + state: Arc, + payload: Option<&str>, + standalone: bool, + last_event_id: Option, +) -> TransportServerResult> { + let payload_string = payload.map(|p| p.to_string()); + + // TODO: this logic should be moved out after refactoing the mcp_stream.rs + let payload_contains_request = payload_string + .as_ref() + .map(|json_str| contains_request(json_str)) + .unwrap_or(Ok(false)); + let Ok(payload_contains_request) = payload_contains_request else { + return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + }; + + // readable stream of string to be used in transport + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + let session_id = Arc::new(session_id); + let stream_id: Arc = if standalone { + Arc::new(DEFAULT_STREAM_ID.to_string()) + } else { + Arc::new(state.stream_id_gen.generate()) + }; + + let event_store = state.event_store.as_ref().map(Arc::clone); + let resumability_enabled = event_store.is_some(); + + let mut transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + if let Some(event_store) = event_store.clone() { + transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); + } + let transport = Arc::new(transport); + + let ping_interval = state.ping_interval; + let runtime_clone = Arc::clone(&runtime); + let stream_id_clone = stream_id.clone(); + let transport_clone = transport.clone(); + + //Start the server runtime + tokio::spawn(async move { + match runtime_clone + .start_stream( + transport_clone, + &stream_id_clone, + ping_interval, + payload_string, + ) + .await + { + Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), + } + let _ = runtime.remove_transport(&stream_id_clone).await; + }); + + // Construct SSE stream + let reader = BufReader::new(write_rx); + + // send outgoing messages from server to the client over the sse stream + let message_stream = stream::unfold(reader, move |mut reader| { + async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + + // empty sse comment to keep-alive + if is_empty_sse_message(&trimmed_line) { + return Some((Ok(SseEvent::default().as_bytes()), reader)); + } + + let (event_id, message) = match ( + resumability_enabled, + trimmed_line.split_once(char::from(ID_SEPARATOR)), + ) { + (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), + _ => (None, trimmed_line), + }; + + let event = match event_id { + Some(id) => SseEvent::default() + .with_data(message) + .with_id(id) + .as_bytes(), + None => SseEvent::default().with_data(message).as_bytes(), + }; + + Some((Ok(event), reader)) + } + Err(e) => Some((Err(e), reader)), + } + } + }); + + let streaming_body: GenericBody = + http_body_util::BodyExt::boxed(StreamBody::new(message_stream.map(|res| { + res.map(Frame::data) + .map_err(|err: std::io::Error| TransportServerError::HttpError(err.to_string())) + }))); + + let session_id_value = HeaderValue::from_str(&session_id) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + let status_code = if !payload_contains_request { + StatusCode::ACCEPTED + } else { + StatusCode::OK + }; + + let response = http::Response::builder() + .status(status_code) + .header("Content-Type", "text/event-stream") + .header(MCP_SESSION_ID_HEADER, session_id_value) + .header("Connection", "keep-alive") + .body(streaming_body) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + // if last_event_id exists we replay messages from the event-store + tokio::spawn(async move { + if let Some(last_event_id) = last_event_id { + if let Some(event_store) = state.event_store.as_ref() { + if let Some(events) = event_store.events_after(last_event_id).await { + for message_payload in events.messages { + // skip storing replay messages + let error = transport.write_str(&message_payload, true).await; + if let Err(error) = error { + tracing::trace!("Error replaying message: {error}") + } + } + } + } + } + }); + + Ok(response) +} + +async fn single_shot_stream_x( + runtime: Arc, + session_id: SessionId, + state: Arc, + payload: Option<&str>, + standalone: bool, +) -> TransportServerResult> { + // readable stream of string to be used in transport + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + let transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + + let stream_id = if standalone { + DEFAULT_STREAM_ID.to_string() + } else { + state.id_generator.generate() + }; + let ping_interval = state.ping_interval; + let runtime_clone = Arc::clone(&runtime); + + let payload_string = payload.map(|p| p.to_string()); + + tokio::spawn(async move { + match runtime_clone + .start_stream( + Arc::new(transport), + &stream_id, + ping_interval, + payload_string, + ) + .await + { + Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + } + let _ = runtime.remove_transport(&stream_id).await; + }); + + let mut reader = BufReader::new(write_rx); + let mut line = String::new(); + let response = match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + Some(Ok(trimmed_line)) + } + Err(e) => Some(Err(e)), + }; + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + headers.insert( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(&session_id).unwrap(), + ); + + match response { + Some(response_result) => match response_result { + Ok(response_str) => { + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); // Uses BodyExt::boxed + + let response = http::Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + + response + } + Err(err) => { + let body = Full::new(Bytes::from(err.to_string())) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + }, + None => { + let body = Full::new(Bytes::from( + "End of the transport stream reached.".to_string(), + )) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::UNPROCESSABLE_ENTITY) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + } +} + +pub async fn process_incoming_message_return_x( + session_id: SessionId, + state: Arc, + payload: &str, +) -> TransportServerResult> { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + + single_shot_stream_x( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + ) + .await + // Ok(StatusCode::OK.into_response()) + } + None => { + let error = SdkError::session_not_found(); + error_response(StatusCode::NOT_FOUND, error) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + } +} + +pub async fn process_incoming_message_x( + session_id: SessionId, + state: Arc, + payload: &str, +) -> TransportServerResult> { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport + // it should be processed by the same transport , therefore no need to call create_sse_stream + let Ok(is_result) = is_result(payload) else { + return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()); + }; + + if is_result { + match runtime + .consume_payload_string(DEFAULT_STREAM_ID, payload) + .await + { + Ok(()) => { + let body = Full::new(Bytes::new()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + Err(err) => { + let error = + SdkError::internal_error().with_message(err.to_string().as_ref()); + error_response(StatusCode::BAD_REQUEST, error) + } + } + } else { + create_sse_stream_x( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + None, + ) + .await + } + } + None => { + let error = SdkError::session_not_found(); + error_response(StatusCode::NOT_FOUND, error) + } + } +} + +pub async fn start_new_session_x( + state: Arc, + payload: &str, +) -> TransportServerResult> { + let session_id: SessionId = state.id_generator.generate(); + + let h: Arc = state.handler.clone(); + // create a new server instance with unique session_id and + let runtime: Arc = server_runtime::create_server_instance( + Arc::clone(&state.server_details), + h, + session_id.to_owned(), + ); + + tracing::info!("a new client joined : {}", &session_id); + + let response = create_sse_stream_x( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + None, + ) + .await; + + if response.is_ok() { + state + .session_store + .set(session_id.to_owned(), runtime.clone()) + .await; + } + response +} + +pub async fn delete_session_x( + session_id: SessionId, + state: Arc, +) -> TransportServerResult> { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + runtime.shutdown().await; + state.session_store.delete(&session_id).await; + tracing::info!("client disconnected : {}", &session_id); + + let body = Full::new(Bytes::from("ok")) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + None => { + let error = SdkError::session_not_found(); + error_response(StatusCode::NOT_FOUND, error) + } + } } diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index d21e5dd..3568cc2 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -31,6 +31,8 @@ pub use sse::*; pub use stdio::*; pub use transport::*; +pub use utils::SseEvent; + // Type alias for session identifier, represented as a String pub type SessionId = String; // Type alias for stream identifier (that will be used at the transport scope), represented as a String diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 034f062..813b0ee 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -18,6 +18,8 @@ pub(crate) use http_utils::*; #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use readable_channel::*; #[cfg(any(feature = "sse", feature = "streamable-http"))] +pub use sse_parser::SseEvent; +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use sse_parser::*; #[cfg(feature = "sse")] pub(crate) use sse_stream::*; diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs index 5933726..d8b936a 100644 --- a/crates/rust-mcp-transport/src/utils/sse_parser.rs +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -14,18 +14,90 @@ pub struct SseEvent { pub data: Option, /// The optional event ID for reconnection or tracking purposes. pub id: Option, + /// Optional reconnection retry interval (in milliseconds). + pub retry: Option, +} + +impl SseEvent { + /// Creates a new `SseEvent` with the given string data. + pub fn new>(data: T) -> Self { + Self { + event: None, + data: Some(Bytes::from(data.into())), + id: None, + retry: None, + } + } + + /// Sets the event name (e.g., "message"). + pub fn with_event>(mut self, event: T) -> Self { + self.event = Some(event.into()); + self + } + + /// Sets the ID of the event. + pub fn with_id>(mut self, id: T) -> Self { + self.id = Some(id.into()); + self + } + + /// Sets the retry interval (in milliseconds). + pub fn with_retry(mut self, retry: u64) -> Self { + self.retry = Some(retry); + self + } + + /// Sets the data as bytes. + pub fn with_data_bytes(mut self, data: Bytes) -> Self { + self.data = Some(data); + self + } + + /// Sets the data. + pub fn with_data(mut self, data: String) -> Self { + self.data = Some(Bytes::from(data)); + self + } + + /// Converts the event into a string in SSE format (ready for HTTP body). + pub fn to_sse_string(&self) -> String { + self.to_string() + } + + pub fn as_bytes(&self) -> Bytes { + Bytes::from(self.to_string()) + } +} + +impl Default for SseEvent { + fn default() -> Self { + Self { + event: Default::default(), + data: Default::default(), + id: Default::default(), + retry: Default::default(), + } + } } impl std::fmt::Display for SseEvent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Emit retry interval + if let Some(retry) = self.retry { + writeln!(f, "retry: {retry}")?; + } + + // Emit ID if let Some(id) = &self.id { writeln!(f, "id: {id}")?; } + // Emit event type if let Some(event) = &self.event { writeln!(f, "event: {event}")?; } + // Emit data lines if let Some(data) = &self.data { match std::str::from_utf8(data) { Ok(text) => { @@ -39,7 +111,7 @@ impl std::fmt::Display for SseEvent { } } - writeln!(f)?; // Trailing newline for SSE message end + writeln!(f)?; // Trailing newline for SSE message end, separates events Ok(()) } } @@ -57,6 +129,7 @@ impl fmt::Debug for SseEvent { .field("event", &self.event) .field("data", &data_str) .field("id", &self.id) + .field("retry", &self.retry) .finish() } } @@ -193,11 +266,15 @@ impl SseParser { // Get event (default to None) let event = fields.get("event").cloned(); let id = fields.get("id").cloned(); + let retry = fields + .get("retry") + .and_then(|r| r.trim().parse::().ok()); Some(SseEvent { event, data: Some(data), id, + retry, }) } } @@ -317,4 +394,20 @@ mod tests { Some(Bytes::from("second\n").as_ref()) ); } + + #[test] + fn test_basic_sse_event() { + let mut parser = SseParser::new(); + let input = Bytes::from("event: message\ndata: Hello\nid: 1\nretry: 5000\n\n"); + + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + + let event = &events[0]; + assert_eq!(event.event.as_deref(), Some("message")); + assert_eq!(event.data.as_deref(), Some(Bytes::from("Hello\n").as_ref())); + assert_eq!(event.id.as_deref(), Some("1")); + assert_eq!(event.retry, Some(5000)); + } } From 8b0a52b08c18e567f274be697861581c723ff45f Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 12 Oct 2025 09:30:38 -0300 Subject: [PATCH 08/14] feat: de-couple stramable http logic from the server framework --- .../rust-mcp-sdk/src/hyper_servers/error.rs | 2 +- .../hyper_servers/routes/messages_routes.rs | 2 +- .../src/hyper_servers/routes/sse_routes.rs | 5 - .../routes/streamable_http_routes.rs | 162 ++--- .../src/mcp_http/mcp_http_handler.rs | 32 +- .../src/mcp_http/mcp_http_utils.rs | 618 +++--------------- crates/rust-mcp-transport/src/lib.rs | 1 + .../src/utils/sse_parser.rs | 12 +- 8 files changed, 175 insertions(+), 659 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/error.rs b/crates/rust-mcp-sdk/src/hyper_servers/error.rs index f0590dd..dd55d8f 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/error.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/error.rs @@ -1,4 +1,4 @@ -use std::{net::AddrParseError, path::Display}; +use std::net::AddrParseError; use axum::{http::StatusCode, response::IntoResponse}; use thiserror::Error; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 39aa983..6447b35 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -14,7 +14,7 @@ use std::{collections::HashMap, sync::Arc}; pub fn routes(sse_message_endpoint: &str) -> Router> { Router::new().route( - remove_query_and_hash(&sse_message_endpoint).as_str(), + remove_query_and_hash(sse_message_endpoint).as_str(), post(handle_messages), ) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 21bfc09..1c2910b 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -93,11 +93,6 @@ pub async fn handle_sse( State(state): State>, ) -> TransportServerResult { let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; - tracing::warn!( - ">>> session_id {:?}, sse_message_endpoint>>> {:?}", - session_id, - sse_message_endpoint - ); let messages_endpoint = SseTransport::::message_endpoint(&sse_message_endpoint, &session_id); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index b7ec2f3..b3ed4bd 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,30 +1,17 @@ -use crate::mcp_http::utils::{ - acceptable_content_type, accepts_event_stream, create_standalone_stream, delete_session, - process_incoming_message, process_incoming_message_return, start_new_session, - valid_streaming_http_accept_header, validate_mcp_protocol_version_header, -}; -use crate::mcp_http::McpAppState; -use crate::schema::schema_utils::SdkError; -use crate::utils::validate_mcp_protocol_version; -use crate::{ - error::McpSdkError, - hyper_servers::{ - error::TransportServerResult, middlewares::protect_dns_rebinding::protect_dns_rebinding, - }, - utils::valid_initialize_method, +use crate::hyper_servers::{ + error::TransportServerResult, middlewares::protect_dns_rebinding::protect_dns_rebinding, }; +use crate::mcp_http::{McpAppState, McpHttpHandler}; use axum::routing::get; use axum::{ extract::{Query, State}, middleware, response::IntoResponse, routing::{delete, post}, - Json, Router, -}; -use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::{ - SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, + Router, }; +use http::{Method, Uri}; +use hyper::HeaderMap; use std::{collections::HashMap, sync::Arc}; pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { @@ -43,121 +30,74 @@ pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router pub async fn handle_streamable_http_get( headers: HeaderMap, + uri: Uri, State(state): State>, ) -> TransportServerResult { - if !accepts_event_stream(&headers) { - let error = SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#); - return Ok((StatusCode::NOT_ACCEPTABLE, Json(error)).into_response()); - } - - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { - let error = - SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); - return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); - } - - let session_id: Option = headers - .get(MCP_SESSION_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); + let mut request = http::Request::builder() + .method(Method::GET) + .uri(uri) + .body("") + .unwrap(); //TODO: error handling - let last_event_id: Option = headers - .get(MCP_LAST_EVENT_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); - - match session_id { - Some(session_id) => { - let res = create_standalone_stream(session_id, last_event_id, state).await?; - Ok(res.into_response()) - } - None => { - let error = SdkError::bad_request().with_message("Bad request: session not found"); - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) + let req_headers = request.headers_mut(); + for (key, value) in headers { + if let Some(k) = key { + req_headers.insert(k, value); } } + + let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let (parts, body) = generic_res.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } pub async fn handle_streamable_http_post( headers: HeaderMap, + uri: Uri, State(state): State>, Query(_params): Query>, payload: String, ) -> TransportServerResult { - if !valid_streaming_http_accept_header(&headers) { - let error = SdkError::bad_request() - .with_message(r#"Client must accept both application/json and text/event-stream"#); - return Ok((StatusCode::NOT_ACCEPTABLE, Json(error)).into_response()); - } + let mut request = http::Request::builder() + .method(Method::POST) + .uri(uri) + .body(payload.as_str()) + .unwrap(); //TODO: error handling - if !acceptable_content_type(&headers) { - let error = SdkError::bad_request() - .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#); - return Ok((StatusCode::UNSUPPORTED_MEDIA_TYPE, Json(error)).into_response()); - } - - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { - let error = - SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); - return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); - } - - let session_id: Option = headers - .get(MCP_SESSION_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); - - //TODO: validate reconnect after disconnect - - match session_id { - // has session-id => write to the existing stream - Some(id) => { - if state.enable_json_response { - let res = process_incoming_message_return(id, state, &payload).await?; - Ok(res.into_response()) - } else { - let res = process_incoming_message(id, state, &payload).await?; - Ok(res.into_response()) - } + let req_headers = request.headers_mut(); + for (key, value) in headers { + if let Some(k) = key { + req_headers.insert(k, value); } - None => match valid_initialize_method(&payload) { - Ok(_) => { - return start_new_session(state, &payload).await; - } - Err(McpSdkError::SdkError(error)) => { - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) - } - Err(error) => { - let error = SdkError::bad_request().with_message(&error.to_string()); - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) - } - }, } + + let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let (parts, body) = generic_res.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } pub async fn handle_streamable_http_delete( headers: HeaderMap, + uri: Uri, State(state): State>, ) -> TransportServerResult { - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { - let error = - SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); - return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); - } - - let session_id: Option = headers - .get(MCP_SESSION_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); + let mut request = http::Request::builder() + .method(Method::DELETE) + .uri(uri) + .body("") + .unwrap(); //TODO: error handling - match session_id { - Some(id) => { - let res = delete_session(id, state).await; - Ok(res.into_response()) - } - None => { - let error = SdkError::bad_request().with_message("Bad Request: Session not found"); - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) + let req_headers = request.headers_mut(); + for (key, value) in headers { + if let Some(k) = key { + req_headers.insert(k, value); } } + + let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let (parts, body) = generic_res.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 36dc3dd..762d3ce 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -4,10 +4,9 @@ use crate::{ error::McpSdkError, mcp_http::{ utils::{ - acceptable_content_type, create_standalone_stream, create_standalone_stream_x, - delete_session, delete_session_x, process_incoming_message_return_x, - process_incoming_message_x, start_new_session_x, valid_streaming_http_accept_header, - GenericBody, + acceptable_content_type, create_standalone_stream, delete_session, + process_incoming_message, process_incoming_message_return, start_new_session, + valid_streaming_http_accept_header, GenericBody, }, McpAppState, }, @@ -15,10 +14,7 @@ use crate::{ schema::schema_utils::SdkError, utils::valid_initialize_method, }; -use axum::response::ErrorResponse; -use bytes::Bytes; -use http::{self, header::CONTENT_TYPE, StatusCode}; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use http::{self, StatusCode}; use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use crate::mcp_http::utils::{ @@ -41,7 +37,7 @@ impl McpHttpHandler { let error = SdkError::bad_request().with_message(&format!( "'{other}' is not a valid HTTP method for StreamableHTTP transport." )); - return error_response(StatusCode::METHOD_NOT_ALLOWED, error); + error_response(StatusCode::METHOD_NOT_ALLOWED, error) } } } @@ -52,7 +48,7 @@ impl McpHttpHandler { ) -> TransportServerResult> { let headers = request.headers(); - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { let error = SdkError::bad_request() .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); return error_response(StatusCode::BAD_REQUEST, error); @@ -64,7 +60,7 @@ impl McpHttpHandler { .map(|s| s.to_string()); match session_id { - Some(id) => delete_session_x(id, state).await, + Some(id) => delete_session(id, state).await, None => { let error = SdkError::bad_request().with_message("Bad Request: Session not found"); error_response(StatusCode::BAD_REQUEST, error) @@ -90,7 +86,7 @@ impl McpHttpHandler { return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error); } - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { let error = SdkError::bad_request() .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); return error_response(StatusCode::BAD_REQUEST, error); @@ -107,14 +103,14 @@ impl McpHttpHandler { // has session-id => write to the existing stream Some(id) => { if state.enable_json_response { - process_incoming_message_return_x(id, state, payload).await + process_incoming_message_return(id, state, payload).await } else { - process_incoming_message_x(id, state, &payload).await + process_incoming_message(id, state, payload).await } } - None => match valid_initialize_method(&payload) { + None => match valid_initialize_method(payload) { Ok(_) => { - return start_new_session_x(state, &payload).await; + return start_new_session(state, payload).await; } Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error), Err(error) => { @@ -137,7 +133,7 @@ impl McpHttpHandler { return error_response(StatusCode::NOT_ACCEPTABLE, error); } - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { let error = SdkError::bad_request() .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); return error_response(StatusCode::BAD_REQUEST, error); @@ -155,7 +151,7 @@ impl McpHttpHandler { match session_id { Some(session_id) => { - let res = create_standalone_stream_x(session_id, last_event_id, state).await; + let res = create_standalone_stream(session_id, last_event_id, state).await; res } None => { diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index f523215..14b8ad7 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -8,27 +8,19 @@ use crate::{ mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, utils::validate_mcp_protocol_version, }; -use axum::{http::HeaderValue, response::IntoResponse}; -use axum::{ - response::{ - sse::{Event, KeepAlive}, - Sse, - }, - Json, -}; +use axum::http::HeaderValue; use bytes::Bytes; use futures::stream; -use http::header::{ACCEPT, CONTENT_TYPE}; +use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE}; use http_body::Frame; use http_body_util::StreamBody; use http_body_util::{combinators::BoxBody, BodyExt, Full}; -use hyper::{header, HeaderMap, StatusCode}; -use rust_mcp_transport::error::TransportError; +use hyper::{HeaderMap, StatusCode}; use rust_mcp_transport::{ EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio_stream::StreamExt; @@ -43,7 +35,7 @@ async fn create_sse_stream( payload: Option<&str>, standalone: bool, last_event_id: Option, -) -> TransportServerResult> { +) -> TransportServerResult> { let payload_string = payload.map(|p| p.to_string()); // TODO: this logic should be moved out after refactoing the mcp_stream.rs @@ -52,7 +44,7 @@ async fn create_sse_stream( .map(|json_str| contains_request(json_str)) .unwrap_or(Ok(false)); let Ok(payload_contains_request) = payload_contains_request else { - return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); + return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()); }; // readable stream of string to be used in transport @@ -119,7 +111,7 @@ async fn create_sse_stream( // empty sse comment to keep-alive if is_empty_sse_message(&trimmed_line) { - return Some((Ok(Event::default()), reader)); + return Some((Ok(SseEvent::default().as_bytes()), reader)); } let (event_id, message) = match ( @@ -131,8 +123,11 @@ async fn create_sse_stream( }; let event = match event_id { - Some(id) => Event::default().data(message).id(id), - None => Event::default().data(message), + Some(id) => SseEvent::default() + .with_data(message) + .with_id(id) + .as_bytes(), + None => SseEvent::default().with_data(message).as_bytes(), }; Some((Ok(event), reader)) @@ -142,16 +137,29 @@ async fn create_sse_stream( } }); - let sse_stream = - Sse::new(message_stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); + // create a stream body + let streaming_body: GenericBody = + http_body_util::BodyExt::boxed(StreamBody::new(message_stream.map(|res| { + res.map(Frame::data) + .map_err(|err: std::io::Error| TransportServerError::HttpError(err.to_string())) + }))); - // Return SSE response with keep-alive - // Create a Response and set headers - let mut response = sse_stream.into_response(); - response.headers_mut().insert( - MCP_SESSION_ID_HEADER, - HeaderValue::from_str(&session_id).unwrap(), - ); + let session_id_value = HeaderValue::from_str(&session_id) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + let status_code = if !payload_contains_request { + StatusCode::ACCEPTED + } else { + StatusCode::OK + }; + + let response = http::Response::builder() + .status(status_code) + .header(CONTENT_TYPE, "text/event-stream") + .header(MCP_SESSION_ID_HEADER, session_id_value) + .header(CONNECTION, "keep-alive") + .body(streaming_body) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; // if last_event_id exists we replay messages from the event-store tokio::spawn(async move { @@ -170,9 +178,6 @@ async fn create_sse_stream( } }); - if !payload_contains_request { - *response.status_mut() = StatusCode::ACCEPTED; - } Ok(response) } @@ -209,7 +214,7 @@ pub async fn create_standalone_stream( session_id: SessionId, last_event_id: Option, state: Arc, -) -> TransportServerResult> { +) -> TransportServerResult> { let runtime = state.session_store.get(&session_id).await.ok_or( TransportServerError::SessionIdInvalid(session_id.to_string()), )?; @@ -218,7 +223,8 @@ pub async fn create_standalone_stream( if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { let error = SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); - return Ok((StatusCode::CONFLICT, Json(error)).into_response()); + return error_response(StatusCode::CONFLICT, error) + .map_err(|err| TransportServerError::HttpError(err.to_string())); } if let Some(last_event_id) = last_event_id.as_ref() { @@ -244,7 +250,7 @@ pub async fn create_standalone_stream( pub async fn start_new_session( state: Arc, payload: &str, -) -> TransportServerResult> { +) -> TransportServerResult> { let session_id: SessionId = state.id_generator.generate(); let h: Arc = state.handler.clone(); @@ -275,14 +281,13 @@ pub async fn start_new_session( } response } - async fn single_shot_stream( runtime: Arc, session_id: SessionId, state: Arc, payload: Option<&str>, standalone: bool, -) -> TransportServerResult> { +) -> TransportServerResult> { // readable stream of string to be used in transport let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); // writable stream to deliver message to the client @@ -333,34 +338,46 @@ async fn single_shot_stream( Err(e) => Some(Err(e)), }; - let mut headers = HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - headers.insert( - MCP_SESSION_ID_HEADER, - HeaderValue::from_str(&session_id).unwrap(), - ); + let session_id_value = HeaderValue::from_str(&session_id) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; match response { Some(response_result) => match response_result { Ok(response_str) => { - Ok((StatusCode::OK, headers, response_str.to_string()).into_response()) + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .header(MCP_SESSION_ID_HEADER, session_id_value) + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + Err(err) => { + let body = Full::new(Bytes::from(err.to_string())) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) } - Err(err) => Ok(( - StatusCode::INTERNAL_SERVER_ERROR, - headers, - Json(err.to_string()), - ) - .into_response()), }, - None => Ok(( - StatusCode::UNPROCESSABLE_ENTITY, - headers, - Json("End of the transport stream reached."), - ) - .into_response()), + None => { + let body = Full::new(Bytes::from( + "End of the transport stream reached.".to_string(), + )) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::UNPROCESSABLE_ENTITY) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } } } @@ -368,14 +385,14 @@ pub async fn process_incoming_message_return( session_id: SessionId, state: Arc, payload: &str, -) -> TransportServerResult { +) -> TransportServerResult> { match state.session_store.get(&session_id).await { Some(runtime) => { let runtime = runtime.lock().await.to_owned(); single_shot_stream( runtime.clone(), - session_id.clone(), + session_id, state.clone(), Some(payload), false, @@ -385,7 +402,8 @@ pub async fn process_incoming_message_return( } None => { let error = SdkError::session_not_found(); - Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) + error_response(StatusCode::NOT_FOUND, error) + .map_err(|err| TransportServerError::HttpError(err.to_string())) } } } @@ -394,14 +412,14 @@ pub async fn process_incoming_message( session_id: SessionId, state: Arc, payload: &str, -) -> TransportServerResult { +) -> TransportServerResult> { match state.session_store.get(&session_id).await { Some(runtime) => { let runtime = runtime.lock().await.to_owned(); // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport // it should be processed by the same transport , therefore no need to call create_sse_stream let Ok(is_result) = is_result(payload) else { - return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); + return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()); }; if is_result { @@ -409,12 +427,21 @@ pub async fn process_incoming_message( .consume_payload_string(DEFAULT_STREAM_ID, payload) .await { - Ok(()) => Ok((StatusCode::ACCEPTED, Json(())).into_response()), - Err(err) => Ok(( - StatusCode::BAD_REQUEST, - Json(SdkError::internal_error().with_message(err.to_string().as_ref())), - ) - .into_response()), + Ok(()) => { + let body = Full::new(Bytes::new()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + Err(err) => { + let error = + SdkError::internal_error().with_message(err.to_string().as_ref()); + error_response(StatusCode::BAD_REQUEST, error) + } } } else { create_sse_stream( @@ -430,7 +457,7 @@ pub async fn process_incoming_message( } None => { let error = SdkError::session_not_found(); - Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) + error_response(StatusCode::NOT_FOUND, error) } } } @@ -442,18 +469,26 @@ pub fn is_empty_sse_message(sse_payload: &str) -> bool { pub async fn delete_session( session_id: SessionId, state: Arc, -) -> TransportServerResult { +) -> TransportServerResult> { match state.session_store.get(&session_id).await { Some(runtime) => { let runtime = runtime.lock().await.to_owned(); runtime.shutdown().await; state.session_store.delete(&session_id).await; tracing::info!("client disconnected : {}", &session_id); - Ok((StatusCode::OK, Json("ok")).into_response()) + + let body = Full::new(Bytes::from("ok")) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) } None => { let error = SdkError::session_not_found(); - Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) + error_response(StatusCode::NOT_FOUND, error) } } } @@ -550,444 +585,3 @@ pub fn error_response( // Ok(response) // } - -pub async fn create_standalone_stream_x( - session_id: SessionId, - last_event_id: Option, - state: Arc, -) -> TransportServerResult> { - let runtime = state.session_store.get(&session_id).await.ok_or( - TransportServerError::SessionIdInvalid(session_id.to_string()), - )?; - let runtime = runtime.lock().await.to_owned(); - - if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { - let error = - SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); - return error_response(StatusCode::CONFLICT, error) - .map_err(|err| TransportServerError::HttpError(err.to_string())); - } - - if let Some(last_event_id) = last_event_id.as_ref() { - tracing::trace!( - "SSE stream re-connected with last-event-id: {}", - last_event_id - ); - } - - let mut response = create_sse_stream_x( - runtime.clone(), - session_id.clone(), - state.clone(), - None, - true, - last_event_id, - ) - .await?; - *response.status_mut() = StatusCode::OK; - Ok(response) -} - -async fn create_sse_stream_x( - runtime: Arc, - session_id: SessionId, - state: Arc, - payload: Option<&str>, - standalone: bool, - last_event_id: Option, -) -> TransportServerResult> { - let payload_string = payload.map(|p| p.to_string()); - - // TODO: this logic should be moved out after refactoing the mcp_stream.rs - let payload_contains_request = payload_string - .as_ref() - .map(|json_str| contains_request(json_str)) - .unwrap_or(Ok(false)); - let Ok(payload_contains_request) = payload_contains_request else { - return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()) - .map_err(|err| TransportServerError::HttpError(err.to_string())); - }; - - // readable stream of string to be used in transport - let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); - // writable stream to deliver message to the client - let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - - let session_id = Arc::new(session_id); - let stream_id: Arc = if standalone { - Arc::new(DEFAULT_STREAM_ID.to_string()) - } else { - Arc::new(state.stream_id_gen.generate()) - }; - - let event_store = state.event_store.as_ref().map(Arc::clone); - let resumability_enabled = event_store.is_some(); - - let mut transport = SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?; - if let Some(event_store) = event_store.clone() { - transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); - } - let transport = Arc::new(transport); - - let ping_interval = state.ping_interval; - let runtime_clone = Arc::clone(&runtime); - let stream_id_clone = stream_id.clone(); - let transport_clone = transport.clone(); - - //Start the server runtime - tokio::spawn(async move { - match runtime_clone - .start_stream( - transport_clone, - &stream_id_clone, - ping_interval, - payload_string, - ) - .await - { - Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), - } - let _ = runtime.remove_transport(&stream_id_clone).await; - }); - - // Construct SSE stream - let reader = BufReader::new(write_rx); - - // send outgoing messages from server to the client over the sse stream - let message_stream = stream::unfold(reader, move |mut reader| { - async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - - // empty sse comment to keep-alive - if is_empty_sse_message(&trimmed_line) { - return Some((Ok(SseEvent::default().as_bytes()), reader)); - } - - let (event_id, message) = match ( - resumability_enabled, - trimmed_line.split_once(char::from(ID_SEPARATOR)), - ) { - (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), - _ => (None, trimmed_line), - }; - - let event = match event_id { - Some(id) => SseEvent::default() - .with_data(message) - .with_id(id) - .as_bytes(), - None => SseEvent::default().with_data(message).as_bytes(), - }; - - Some((Ok(event), reader)) - } - Err(e) => Some((Err(e), reader)), - } - } - }); - - let streaming_body: GenericBody = - http_body_util::BodyExt::boxed(StreamBody::new(message_stream.map(|res| { - res.map(Frame::data) - .map_err(|err: std::io::Error| TransportServerError::HttpError(err.to_string())) - }))); - - let session_id_value = HeaderValue::from_str(&session_id) - .map_err(|err| TransportServerError::HttpError(err.to_string()))?; - - let status_code = if !payload_contains_request { - StatusCode::ACCEPTED - } else { - StatusCode::OK - }; - - let response = http::Response::builder() - .status(status_code) - .header("Content-Type", "text/event-stream") - .header(MCP_SESSION_ID_HEADER, session_id_value) - .header("Connection", "keep-alive") - .body(streaming_body) - .map_err(|err| TransportServerError::HttpError(err.to_string()))?; - - // if last_event_id exists we replay messages from the event-store - tokio::spawn(async move { - if let Some(last_event_id) = last_event_id { - if let Some(event_store) = state.event_store.as_ref() { - if let Some(events) = event_store.events_after(last_event_id).await { - for message_payload in events.messages { - // skip storing replay messages - let error = transport.write_str(&message_payload, true).await; - if let Err(error) = error { - tracing::trace!("Error replaying message: {error}") - } - } - } - } - } - }); - - Ok(response) -} - -async fn single_shot_stream_x( - runtime: Arc, - session_id: SessionId, - state: Arc, - payload: Option<&str>, - standalone: bool, -) -> TransportServerResult> { - // readable stream of string to be used in transport - let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); - // writable stream to deliver message to the client - let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - - let transport = SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?; - - let stream_id = if standalone { - DEFAULT_STREAM_ID.to_string() - } else { - state.id_generator.generate() - }; - let ping_interval = state.ping_interval; - let runtime_clone = Arc::clone(&runtime); - - let payload_string = payload.map(|p| p.to_string()); - - tokio::spawn(async move { - match runtime_clone - .start_stream( - Arc::new(transport), - &stream_id, - ping_interval, - payload_string, - ) - .await - { - Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), - } - let _ = runtime.remove_transport(&stream_id).await; - }); - - let mut reader = BufReader::new(write_rx); - let mut line = String::new(); - let response = match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some(Ok(trimmed_line)) - } - Err(e) => Some(Err(e)), - }; - - let mut headers = HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - headers.insert( - MCP_SESSION_ID_HEADER, - HeaderValue::from_str(&session_id).unwrap(), - ); - - match response { - Some(response_result) => match response_result { - Ok(response_str) => { - let body = Full::new(Bytes::from(response_str)) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed(); // Uses BodyExt::boxed - - let response = http::Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(body) - .map_err(|err| TransportServerError::HttpError(err.to_string())); - - response - } - Err(err) => { - let body = Full::new(Bytes::from(err.to_string())) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed(); - http::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header(CONTENT_TYPE, "application/json") - .body(body) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - } - }, - None => { - let body = Full::new(Bytes::from( - "End of the transport stream reached.".to_string(), - )) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed(); - http::Response::builder() - .status(StatusCode::UNPROCESSABLE_ENTITY) - .header(CONTENT_TYPE, "application/json") - .body(body) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - } - } -} - -pub async fn process_incoming_message_return_x( - session_id: SessionId, - state: Arc, - payload: &str, -) -> TransportServerResult> { - match state.session_store.get(&session_id).await { - Some(runtime) => { - let runtime = runtime.lock().await.to_owned(); - - single_shot_stream_x( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - ) - .await - // Ok(StatusCode::OK.into_response()) - } - None => { - let error = SdkError::session_not_found(); - error_response(StatusCode::NOT_FOUND, error) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - } - } -} - -pub async fn process_incoming_message_x( - session_id: SessionId, - state: Arc, - payload: &str, -) -> TransportServerResult> { - match state.session_store.get(&session_id).await { - Some(runtime) => { - let runtime = runtime.lock().await.to_owned(); - // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport - // it should be processed by the same transport , therefore no need to call create_sse_stream - let Ok(is_result) = is_result(payload) else { - return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()); - }; - - if is_result { - match runtime - .consume_payload_string(DEFAULT_STREAM_ID, payload) - .await - { - Ok(()) => { - let body = Full::new(Bytes::new()) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed(); - http::Response::builder() - .status(200) - .header("Content-Type", "application/json") - .body(body) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - } - Err(err) => { - let error = - SdkError::internal_error().with_message(err.to_string().as_ref()); - error_response(StatusCode::BAD_REQUEST, error) - } - } - } else { - create_sse_stream_x( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - None, - ) - .await - } - } - None => { - let error = SdkError::session_not_found(); - error_response(StatusCode::NOT_FOUND, error) - } - } -} - -pub async fn start_new_session_x( - state: Arc, - payload: &str, -) -> TransportServerResult> { - let session_id: SessionId = state.id_generator.generate(); - - let h: Arc = state.handler.clone(); - // create a new server instance with unique session_id and - let runtime: Arc = server_runtime::create_server_instance( - Arc::clone(&state.server_details), - h, - session_id.to_owned(), - ); - - tracing::info!("a new client joined : {}", &session_id); - - let response = create_sse_stream_x( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - None, - ) - .await; - - if response.is_ok() { - state - .session_store - .set(session_id.to_owned(), runtime.clone()) - .await; - } - response -} - -pub async fn delete_session_x( - session_id: SessionId, - state: Arc, -) -> TransportServerResult> { - match state.session_store.get(&session_id).await { - Some(runtime) => { - let runtime = runtime.lock().await.to_owned(); - runtime.shutdown().await; - state.session_store.delete(&session_id).await; - tracing::info!("client disconnected : {}", &session_id); - - let body = Full::new(Bytes::from("ok")) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed(); - http::Response::builder() - .status(200) - .header("Content-Type", "application/json") - .body(body) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - } - None => { - let error = SdkError::session_not_found(); - error_response(StatusCode::NOT_FOUND, error) - } - } -} diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 3568cc2..7566290 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -31,6 +31,7 @@ pub use sse::*; pub use stdio::*; pub use transport::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub use utils::SseEvent; // Type alias for session identifier, represented as a String diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs index d8b936a..9bd8811 100644 --- a/crates/rust-mcp-transport/src/utils/sse_parser.rs +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -7,6 +7,7 @@ const BUFFER_CAPACITY: usize = 1024; /// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. /// /// Contains the event type, data payload, and optional event ID. +#[derive(Clone, Default)] pub struct SseEvent { /// The optional event type (e.g., "message"). pub event: Option, @@ -69,17 +70,6 @@ impl SseEvent { } } -impl Default for SseEvent { - fn default() -> Self { - Self { - event: Default::default(), - data: Default::default(), - id: Default::default(), - retry: Default::default(), - } - } -} - impl std::fmt::Display for SseEvent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Emit retry interval From 8a16d7584d9ca3dab0a6d0ebc930b12508ea5dc8 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 12 Oct 2025 11:27:02 -0300 Subject: [PATCH 09/14] chore: move sse event out to own module --- crates/rust-mcp-transport/src/utils.rs | 35 +++-- .../rust-mcp-transport/src/utils/sse_event.rs | 122 +++++++++++++++++ .../src/utils/sse_parser.rs | 124 +----------------- 3 files changed, 149 insertions(+), 132 deletions(-) create mode 100644 crates/rust-mcp-transport/src/utils/sse_event.rs diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 813b0ee..36977a2 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -1,42 +1,57 @@ mod cancellation_token; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod http_utils; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod readable_channel; + +#[cfg(any(feature = "sse", feature = "streamable-http"))] +mod sse_event; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod sse_parser; + #[cfg(feature = "sse")] mod sse_stream; + #[cfg(feature = "streamable-http")] mod streamable_http_stream; + +mod time_utils; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; +use crate::error::{TransportError, TransportResult}; +use crate::schema::schema_utils::SdkError; pub(crate) use cancellation_token::*; + +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use crate::SessionId; + #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; + #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use readable_channel::*; + #[cfg(any(feature = "sse", feature = "streamable-http"))] -pub use sse_parser::SseEvent; +pub use sse_event::*; + #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use sse_parser::*; + #[cfg(feature = "sse")] pub(crate) use sse_stream::*; + #[cfg(feature = "streamable-http")] pub(crate) use streamable_http_stream::*; -#[cfg(any(feature = "sse", feature = "streamable-http"))] -pub(crate) use writable_channel::*; -mod time_utils; -pub use time_utils::*; -use crate::schema::schema_utils::SdkError; +pub use time_utils::*; use tokio::time::{timeout, Duration}; - -use crate::error::{TransportError, TransportResult}; - #[cfg(any(feature = "sse", feature = "streamable-http"))] -use crate::SessionId; +pub(crate) use writable_channel::*; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where diff --git a/crates/rust-mcp-transport/src/utils/sse_event.rs b/crates/rust-mcp-transport/src/utils/sse_event.rs new file mode 100644 index 0000000..5837807 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/sse_event.rs @@ -0,0 +1,122 @@ +use bytes::Bytes; +use core::fmt; + +/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. +/// +/// Contains the event type, data payload, and optional event ID. +#[derive(Clone, Default)] +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, + /// Optional reconnection retry interval (in milliseconds). + pub retry: Option, +} + +impl SseEvent { + /// Creates a new `SseEvent` with the given string data. + pub fn new>(data: T) -> Self { + Self { + event: None, + data: Some(Bytes::from(data.into())), + id: None, + retry: None, + } + } + + /// Sets the event name (e.g., "message"). + pub fn with_event>(mut self, event: T) -> Self { + self.event = Some(event.into()); + self + } + + /// Sets the ID of the event. + pub fn with_id>(mut self, id: T) -> Self { + self.id = Some(id.into()); + self + } + + /// Sets the retry interval (in milliseconds). + pub fn with_retry(mut self, retry: u64) -> Self { + self.retry = Some(retry); + self + } + + /// Sets the data as bytes. + pub fn with_data_bytes(mut self, data: Bytes) -> Self { + self.data = Some(data); + self + } + + /// Sets the data. + pub fn with_data(mut self, data: String) -> Self { + self.data = Some(Bytes::from(data)); + self + } + + /// Converts the event into a string in SSE format (ready for HTTP body). + pub fn to_sse_string(&self) -> String { + self.to_string() + } + + pub fn as_bytes(&self) -> Bytes { + Bytes::from(self.to_string()) + } +} + +impl std::fmt::Display for SseEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Emit retry interval + if let Some(retry) = self.retry { + writeln!(f, "retry: {retry}")?; + } + + // Emit ID + if let Some(id) = &self.id { + writeln!(f, "id: {id}")?; + } + + // Emit event type + if let Some(event) = &self.event { + writeln!(f, "event: {event}")?; + } + + // Emit data lines + if let Some(data) = &self.data { + match std::str::from_utf8(data) { + Ok(text) => { + for line in text.lines() { + writeln!(f, "data: {line}")?; + } + } + Err(_) => { + writeln!(f, "data: [binary data]")?; + } + } + } + + writeln!(f)?; // Trailing newline for SSE message end, separates events + Ok(()) + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self + .data + .as_ref() + .map(|b| String::from_utf8_lossy(b).to_string()); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .field("retry", &self.retry) + .finish() + } +} diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs index 9bd8811..3074e9f 100644 --- a/crates/rust-mcp-transport/src/utils/sse_parser.rs +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -1,129 +1,9 @@ -use core::fmt; +use bytes::{Bytes, BytesMut}; use std::collections::HashMap; -use bytes::{Bytes, BytesMut}; +use super::SseEvent; const BUFFER_CAPACITY: usize = 1024; -/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. -/// -/// Contains the event type, data payload, and optional event ID. -#[derive(Clone, Default)] -pub struct SseEvent { - /// The optional event type (e.g., "message"). - pub event: Option, - /// The optional data payload of the event, stored as bytes. - pub data: Option, - /// The optional event ID for reconnection or tracking purposes. - pub id: Option, - /// Optional reconnection retry interval (in milliseconds). - pub retry: Option, -} - -impl SseEvent { - /// Creates a new `SseEvent` with the given string data. - pub fn new>(data: T) -> Self { - Self { - event: None, - data: Some(Bytes::from(data.into())), - id: None, - retry: None, - } - } - - /// Sets the event name (e.g., "message"). - pub fn with_event>(mut self, event: T) -> Self { - self.event = Some(event.into()); - self - } - - /// Sets the ID of the event. - pub fn with_id>(mut self, id: T) -> Self { - self.id = Some(id.into()); - self - } - - /// Sets the retry interval (in milliseconds). - pub fn with_retry(mut self, retry: u64) -> Self { - self.retry = Some(retry); - self - } - - /// Sets the data as bytes. - pub fn with_data_bytes(mut self, data: Bytes) -> Self { - self.data = Some(data); - self - } - - /// Sets the data. - pub fn with_data(mut self, data: String) -> Self { - self.data = Some(Bytes::from(data)); - self - } - - /// Converts the event into a string in SSE format (ready for HTTP body). - pub fn to_sse_string(&self) -> String { - self.to_string() - } - - pub fn as_bytes(&self) -> Bytes { - Bytes::from(self.to_string()) - } -} - -impl std::fmt::Display for SseEvent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // Emit retry interval - if let Some(retry) = self.retry { - writeln!(f, "retry: {retry}")?; - } - - // Emit ID - if let Some(id) = &self.id { - writeln!(f, "id: {id}")?; - } - - // Emit event type - if let Some(event) = &self.event { - writeln!(f, "event: {event}")?; - } - - // Emit data lines - if let Some(data) = &self.data { - match std::str::from_utf8(data) { - Ok(text) => { - for line in text.lines() { - writeln!(f, "data: {line}")?; - } - } - Err(_) => { - writeln!(f, "data: [binary data]")?; - } - } - } - - writeln!(f)?; // Trailing newline for SSE message end, separates events - Ok(()) - } -} - -impl fmt::Debug for SseEvent { - /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string - /// (with lossy conversion if invalid UTF-8 is encountered). - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let data_str = self - .data - .as_ref() - .map(|b| String::from_utf8_lossy(b).to_string()); - - f.debug_struct("SseEvent") - .field("event", &self.event) - .field("data", &data_str) - .field("id", &self.id) - .field("retry", &self.retry) - .finish() - } -} - /// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. /// This Parser is specifically designed for MCP messages and with no multi-line data support /// From 0682d150aaf5eff89576d37016cf4fc9070af8f4 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 12 Oct 2025 11:59:24 -0300 Subject: [PATCH 10/14] chore: move dns protection out of middlewares --- .../src/hyper_servers/middlewares.rs | 1 - .../middlewares/protect_dns_rebinding.rs | 65 ---------------- .../src/hyper_servers/routes/sse_routes.rs | 9 +-- .../routes/streamable_http_routes.rs | 11 +-- .../src/mcp_http/mcp_http_handler.rs | 69 ++++++++++------- .../src/mcp_http/mcp_http_utils.rs | 76 ++++++++++++------- 6 files changed, 90 insertions(+), 141 deletions(-) delete mode 100644 crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs index 0222952..612510e 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs @@ -1,2 +1 @@ -pub(crate) mod protect_dns_rebinding; pub(crate) mod session_id_gen; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs deleted file mode 100644 index 3ba8c85..0000000 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs +++ /dev/null @@ -1,65 +0,0 @@ -use crate::{mcp_http::McpAppState, schema::schema_utils::SdkError}; -use axum::{ - extract::{Request, State}, - middleware::Next, - response::IntoResponse, - Json, -}; -use hyper::{ - header::{HOST, ORIGIN}, - HeaderMap, StatusCode, -}; -use std::sync::Arc; - -// Middleware to protect against DNS rebinding attacks by validating Host and Origin headers. -pub async fn protect_dns_rebinding( - headers: HeaderMap, - State(state): State>, - request: Request, - next: Next, -) -> impl IntoResponse { - if !state.needs_dns_protection() { - // If protection is not needed, pass the request to the next handler - return next.run(request).await.into_response(); - } - - if let Some(allowed_hosts) = state.allowed_hosts.as_ref() { - if !allowed_hosts.is_empty() { - let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else { - let error = SdkError::bad_request().with_message("Invalid Host header: [unknown] "); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - }; - - if !allowed_hosts - .iter() - .any(|allowed| allowed.eq_ignore_ascii_case(host)) - { - let error = SdkError::bad_request() - .with_message(format!("Invalid Host header: \"{host}\" ").as_str()); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - } - } - } - - if let Some(allowed_origins) = state.allowed_origins.as_ref() { - if !allowed_origins.is_empty() { - let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else { - let error = - SdkError::bad_request().with_message("Invalid Origin header: [unknown] "); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - }; - - if !allowed_origins - .iter() - .any(|allowed| allowed.eq_ignore_ascii_case(origin)) - { - let error = SdkError::bad_request() - .with_message(format!("Invalid Origin header: \"{origin}\" ").as_str()); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - } - } - } - - // If all checks pass, proceed to the next handler in the chain - next.run(request).await -} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 1c2910b..1548aea 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -3,10 +3,7 @@ use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ hyper_servers::{ - error::TransportServerResult, - middlewares::{ - protect_dns_rebinding::protect_dns_rebinding, session_id_gen::generate_session_id, - }, + error::TransportServerResult, middlewares::session_id_gen::generate_session_id, }, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, @@ -71,10 +68,6 @@ pub fn routes( state.clone(), generate_session_id, )) - .route_layer(middleware::from_fn_with_state( - state.clone(), - protect_dns_rebinding, - )) } /// Handles Server-Sent Events (SSE) connections diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index b3ed4bd..1610dc6 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,6 +1,4 @@ -use crate::hyper_servers::{ - error::TransportServerResult, middlewares::protect_dns_rebinding::protect_dns_rebinding, -}; +use crate::hyper_servers::error::TransportServerResult; use crate::mcp_http::{McpAppState, McpHttpHandler}; use axum::routing::get; use axum::{ @@ -10,8 +8,7 @@ use axum::{ routing::{delete, post}, Router, }; -use http::{Method, Uri}; -use hyper::HeaderMap; +use http::{HeaderMap, Method, Uri}; use std::{collections::HashMap, sync::Arc}; pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { @@ -22,10 +19,6 @@ pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router streamable_http_endpoint, delete(handle_streamable_http_delete), ) - .route_layer(middleware::from_fn_with_state( - state.clone(), - protect_dns_rebinding, - )) } pub async fn handle_streamable_http_get( diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 762d3ce..8248afc 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1,12 +1,12 @@ -use std::sync::Arc; +use std::{os::macos::raw::stat, sync::Arc}; use crate::{ error::McpSdkError, mcp_http::{ utils::{ acceptable_content_type, create_standalone_stream, delete_session, - process_incoming_message, process_incoming_message_return, start_new_session, - valid_streaming_http_accept_header, GenericBody, + process_incoming_message, process_incoming_message_return, protect_dns_rebinding, + start_new_session, valid_streaming_http_accept_header, GenericBody, }, McpAppState, }, @@ -28,6 +28,14 @@ impl McpHttpHandler { request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { + // Enforces DNS rebinding protection if required by state. + // If protection fails, respond with HTTP 403 Forbidden. + if state.needs_dns_protection() { + if let Err(error) = protect_dns_rebinding(&request.headers(), state.clone()).await { + return error_response(StatusCode::FORBIDDEN, error); + } + } + let method = request.method(); match method { &http::Method::GET => return Self::handle_http_get(request, state).await, @@ -42,32 +50,7 @@ impl McpHttpHandler { } } - async fn handle_http_delete( - request: http::Request<&str>, - state: Arc, - ) -> TransportServerResult> { - let headers = request.headers(); - - if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { - let error = SdkError::bad_request() - .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); - return error_response(StatusCode::BAD_REQUEST, error); - } - - let session_id: Option = headers - .get(MCP_SESSION_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); - - match session_id { - Some(id) => delete_session(id, state).await, - None => { - let error = SdkError::bad_request().with_message("Bad Request: Session not found"); - error_response(StatusCode::BAD_REQUEST, error) - } - } - } - + /// Processes POST requests for the Streamable HTTP Protocol async fn handle_http_post( request: http::Request<&str>, state: Arc, @@ -121,6 +104,7 @@ impl McpHttpHandler { } } + /// Processes GET requests for the Streamable HTTP Protocol async fn handle_http_get( request: http::Request<&str>, state: Arc, @@ -160,4 +144,31 @@ impl McpHttpHandler { } } } + + /// Processes DELETE requests for the Streamable HTTP Protocol + async fn handle_http_delete( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let headers = request.headers(); + + if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { + let error = SdkError::bad_request() + .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return error_response(StatusCode::BAD_REQUEST, error); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + match session_id { + Some(id) => delete_session(id, state).await, + None => { + let error = SdkError::bad_request().with_message("Bad Request: Session not found"); + error_response(StatusCode::BAD_REQUEST, error) + } + } + } } diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index 14b8ad7..c4bbc0b 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -11,7 +11,7 @@ use crate::{ use axum::http::HeaderValue; use bytes::Bytes; use futures::stream; -use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE}; +use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE, HOST, ORIGIN}; use http_body::Frame; use http_body_util::StreamBody; use http_body_util::{combinators::BoxBody, BodyExt, Full}; @@ -557,31 +557,49 @@ pub fn error_response( .map_err(|err| TransportServerError::HttpError(err.to_string())) } -// pub fn error_response( -// status_code: StatusCode, -// error: SdkError, -// headers: Option -// ) -> TransportServerResult> { -// let error_string = serde_json::to_string(&error).unwrap_or_default(); -// let body = Full::new(Bytes::from(error_string)) -// .map_err(|err| TransportServerError::HttpError(err.to_string())) -// .boxed(); - -// let mut response = http::Response::builder() -// .status(status_code) -// .header(CONTENT_TYPE, "application/json") -// .body(body) -// .map_err(|err| TransportServerError::HttpError(err.to_string()))?; - -// if let Some(header_map) = headers { -// let response_headers = response.headers_mut(); -// for (key, value) in header_map.into_iter() { -// // Only insert valid headers (Some keys), as `HeaderMap` can contain pseudo-headers -// if let Some(key) = key { -// response_headers.insert(key, value); -// } -// } -// } - -// Ok(response) -// } +// Protect against DNS rebinding attacks by validating Host and Origin headers. +pub(crate) async fn protect_dns_rebinding( + headers: &http::HeaderMap, + state: Arc, +) -> Result<(), SdkError> { + if !state.needs_dns_protection() { + // If protection is not needed, pass the request to the next handler + return Ok(()); + } + + if let Some(allowed_hosts) = state.allowed_hosts.as_ref() { + if !allowed_hosts.is_empty() { + let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else { + return Err(SdkError::bad_request().with_message("Invalid Host header: [unknown] ")); + }; + + if !allowed_hosts + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(host)) + { + return Err(SdkError::bad_request() + .with_message(format!("Invalid Host header: \"{host}\" ").as_str())); + } + } + } + + if let Some(allowed_origins) = state.allowed_origins.as_ref() { + if !allowed_origins.is_empty() { + let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else { + return Err( + SdkError::bad_request().with_message("Invalid Origin header: [unknown] ") + ); + }; + + if !allowed_origins + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(origin)) + { + return Err(SdkError::bad_request() + .with_message(format!("Invalid Origin header: \"{origin}\" ").as_str())); + } + } + } + + Ok(()) +} From c0a97eb7c23dad46d7704f59a8643a7886fe432d Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 12 Oct 2025 12:04:34 -0300 Subject: [PATCH 11/14] chore: remove session id generator middleware --- crates/rust-mcp-sdk/src/hyper_servers.rs | 1 - .../src/hyper_servers/middlewares.rs | 1 - .../middlewares/session_id_gen.rs | 21 ------------------- .../src/hyper_servers/routes/sse_routes.rs | 20 +++++++----------- 4 files changed, 7 insertions(+), 36 deletions(-) delete mode 100644 crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs delete mode 100644 crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs diff --git a/crates/rust-mcp-sdk/src/hyper_servers.rs b/crates/rust-mcp-sdk/src/hyper_servers.rs index 318720e..87307c0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers.rs @@ -2,7 +2,6 @@ pub mod error; pub mod hyper_runtime; pub mod hyper_server; pub mod hyper_server_core; -mod middlewares; mod routes; mod server; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs deleted file mode 100644 index 612510e..0000000 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs +++ /dev/null @@ -1 +0,0 @@ -pub(crate) mod session_id_gen; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs deleted file mode 100644 index 878e3ee..0000000 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::mcp_http::McpAppState; -use axum::{ - extract::{Request, State}, - middleware::Next, - response::Response, -}; -use hyper::StatusCode; -use rust_mcp_transport::SessionId; -use std::sync::Arc; - -// Middleware to generate and attach a session ID -pub async fn generate_session_id( - State(state): State>, - mut request: Request, - next: Next, -) -> Result { - let session_id: SessionId = state.id_generator.generate(); - request.extensions_mut().insert(session_id); - // Proceed to the next middleware or handler - Ok(next.run(request).await) -} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 1548aea..cd76ed4 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -2,9 +2,7 @@ use crate::mcp_http::McpAppState; use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ - hyper_servers::{ - error::TransportServerResult, middlewares::session_id_gen::generate_session_id, - }, + hyper_servers::error::TransportServerResult, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::mcp_handler::McpServerHandler, @@ -59,15 +57,10 @@ pub fn routes( sse_message_endpoint: &str, ) -> Router> { let sse_message_endpoint = SseMessageEndpoint(sse_message_endpoint.to_string()); - Router::new() - .route( - sse_endpoint, - get(handle_sse).layer(Extension(sse_message_endpoint)), - ) - .route_layer(middleware::from_fn_with_state( - state.clone(), - generate_session_id, - )) + Router::new().route( + sse_endpoint, + get(handle_sse).layer(Extension(sse_message_endpoint)), + ) } /// Handles Server-Sent Events (SSE) connections @@ -81,12 +74,13 @@ pub fn routes( /// # Returns /// * `TransportServerResult` - The SSE response stream or an error pub async fn handle_sse( - Extension(session_id): Extension, Extension(sse_message_endpoint): Extension, State(state): State>, ) -> TransportServerResult { let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; + let session_id: SessionId = state.id_generator.generate(); + let messages_endpoint = SseTransport::::message_endpoint(&sse_message_endpoint, &session_id); From 6dff4e6461556a1a74058352af6e16a1a409365d Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 12 Oct 2025 12:23:12 -0300 Subject: [PATCH 12/14] chore: add new methdo to McpHttpHandler for creatig http::Request --- .../rust-mcp-sdk/src/hyper_servers/routes.rs | 1 - .../src/hyper_servers/routes/sse_routes.rs | 1 - .../routes/streamable_http_routes.rs | 46 ++---------------- .../src/mcp_http/mcp_http_handler.rs | 47 ++++++++++++++++--- 4 files changed, 45 insertions(+), 50 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index cd79580..6bc4411 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -23,7 +23,6 @@ use std::sync::Arc; pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> Router { let router: Router = Router::new() .merge(streamable_http_routes::routes( - state.clone(), server_options.streamable_http_endpoint(), )) .merge({ diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index cd76ed4..19493dc 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -9,7 +9,6 @@ use crate::{ }; use axum::{ extract::State, - middleware, response::{ sse::{Event, KeepAlive}, IntoResponse, Sse, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 1610dc6..6f2e470 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -3,7 +3,6 @@ use crate::mcp_http::{McpAppState, McpHttpHandler}; use axum::routing::get; use axum::{ extract::{Query, State}, - middleware, response::IntoResponse, routing::{delete, post}, Router, @@ -11,7 +10,7 @@ use axum::{ use http::{HeaderMap, Method, Uri}; use std::{collections::HashMap, sync::Arc}; -pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { +pub fn routes(streamable_http_endpoint: &str) -> Router> { Router::new() .route(streamable_http_endpoint, get(handle_streamable_http_get)) .route(streamable_http_endpoint, post(handle_streamable_http_post)) @@ -26,19 +25,7 @@ pub async fn handle_streamable_http_get( uri: Uri, State(state): State>, ) -> TransportServerResult { - let mut request = http::Request::builder() - .method(Method::GET) - .uri(uri) - .body("") - .unwrap(); //TODO: error handling - - let req_headers = request.headers_mut(); - for (key, value) in headers { - if let Some(k) = key { - req_headers.insert(k, value); - } - } - + let request = McpHttpHandler::create_request(Method::GET, uri, headers, None); let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; let (parts, body) = generic_res.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); @@ -52,19 +39,8 @@ pub async fn handle_streamable_http_post( Query(_params): Query>, payload: String, ) -> TransportServerResult { - let mut request = http::Request::builder() - .method(Method::POST) - .uri(uri) - .body(payload.as_str()) - .unwrap(); //TODO: error handling - - let req_headers = request.headers_mut(); - for (key, value) in headers { - if let Some(k) = key { - req_headers.insert(k, value); - } - } - + let request = + McpHttpHandler::create_request(Method::POST, uri, headers, Some(payload.as_str())); let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; let (parts, body) = generic_res.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); @@ -76,19 +52,7 @@ pub async fn handle_streamable_http_delete( uri: Uri, State(state): State>, ) -> TransportServerResult { - let mut request = http::Request::builder() - .method(Method::DELETE) - .uri(uri) - .body("") - .unwrap(); //TODO: error handling - - let req_headers = request.headers_mut(); - for (key, value) in headers { - if let Some(k) = key { - req_headers.insert(k, value); - } - } - + let request = McpHttpHandler::create_request(Method::DELETE, uri, headers, None); let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; let (parts, body) = generic_res.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 8248afc..6a02f8d 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1,5 +1,6 @@ -use std::{os::macos::raw::stat, sync::Arc}; - +use crate::mcp_http::utils::{ + accepts_event_stream, error_response, validate_mcp_protocol_version_header, +}; use crate::{ error::McpSdkError, mcp_http::{ @@ -14,15 +15,47 @@ use crate::{ schema::schema_utils::SdkError, utils::valid_initialize_method, }; -use http::{self, StatusCode}; +use http::{self, HeaderMap, Method, StatusCode, Uri}; use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; - -use crate::mcp_http::utils::{ - accepts_event_stream, error_response, validate_mcp_protocol_version_header, -}; +use std::sync::Arc; pub struct McpHttpHandler {} +impl McpHttpHandler { + /// Creates a new HTTP request with the given method, URI, headers, and optional body. + /// + /// # Arguments + /// + /// * `method` - The HTTP method to use (e.g., GET, POST). + /// * `uri` - The target URI for the request. + /// * `headers` - A map of optional header keys and their corresponding values. + /// * `body` - An optional string slice representing the request body. + /// + /// # Returns + /// + /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body. + /// If the `body` is `None`, an empty string is used as the default. + /// + pub fn create_request( + method: Method, + uri: Uri, + headers: HeaderMap, + body: Option<&str>, + ) -> http::Request<&str> { + let mut request = http::Request::default(); + *request.method_mut() = method; + *request.uri_mut() = uri; + *request.body_mut() = body.unwrap_or_default(); + let req_headers = request.headers_mut(); + for (key, value) in headers { + if let Some(k) = key { + req_headers.insert(k, value); + } + } + request + } +} + impl McpHttpHandler { pub async fn handle_streamable_http( request: http::Request<&str>, From 7cf5f381480fa3525f4d7c40d2220cc1b475db62 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 12 Oct 2025 15:33:07 -0300 Subject: [PATCH 13/14] chore: update sse handlers to use mcp_http_handler --- .../rust-mcp-sdk/src/hyper_servers/routes.rs | 3 +- .../hyper_servers/routes/messages_routes.rs | 48 ++---- .../src/hyper_servers/routes/sse_routes.rs | 140 ++--------------- .../rust-mcp-sdk/src/hyper_servers/server.rs | 13 +- .../src/mcp_http/mcp_http_handler.rs | 52 ++++++- .../src/mcp_http/mcp_http_utils.rs | 141 ++++++++++++++++++ 6 files changed, 220 insertions(+), 177 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index 6bc4411..7dc33f2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -1,5 +1,6 @@ pub mod fallback_routes; pub mod messages_routes; +#[cfg(any(feature = "sse"))] pub mod sse_routes; pub mod streamable_http_routes; @@ -27,10 +28,10 @@ pub fn app_routes(state: Arc, server_options: &HyperServerOptions) )) .merge({ let mut r = Router::new(); + #[cfg(any(feature = "sse"))] if server_options.sse_support { r = r .merge(sse_routes::routes( - state.clone(), server_options.sse_endpoint(), server_options.sse_messages_endpoint(), )) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 6447b35..65490a3 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -1,16 +1,11 @@ use crate::{ - hyper_servers::error::{TransportServerError, TransportServerResult}, - mcp_http::McpAppState, - mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, + hyper_servers::error::TransportServerResult, + mcp_http::{McpAppState, McpHttpHandler}, utils::remove_query_and_hash, }; -use axum::{ - extract::{Query, State}, - response::IntoResponse, - routing::post, - Router, -}; -use std::{collections::HashMap, sync::Arc}; +use axum::{extract::State, response::IntoResponse, routing::post, Router}; +use http::{HeaderMap, Method, Uri}; +use std::sync::Arc; pub fn routes(sse_message_endpoint: &str) -> Router> { Router::new().route( @@ -20,33 +15,14 @@ pub fn routes(sse_message_endpoint: &str) -> Router> { } pub async fn handle_messages( + uri: Uri, + headers: HeaderMap, State(state): State>, - Query(params): Query>, message: String, ) -> TransportServerResult { - let session_id = params - .get("sessionId") - .ok_or(TransportServerError::SessionIdMissing)?; - - // transmit to the readable stream, that transport is reading from - let transmit = - state - .session_store - .get(session_id) - .await - .ok_or(TransportServerError::SessionIdInvalid( - session_id.to_string(), - ))?; - - let transmit = transmit.lock().await; - - transmit - .consume_payload_string(DEFAULT_STREAM_ID, &message) - .await - .map_err(|err| { - tracing::trace!("{}", err); - TransportServerError::StreamIoError(err.to_string()) - })?; - - Ok(axum::http::StatusCode::ACCEPTED) + let request = McpHttpHandler::create_request(Method::POST, uri, headers, Some(&message)); + let generic_response = McpHttpHandler::handle_sse_message(request, state.clone()).await?; + let (parts, body) = generic_response.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 19493dc..e13c724 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,45 +1,11 @@ -use crate::mcp_http::McpAppState; -use crate::mcp_server::error::TransportServerError; -use crate::schema::schema_utils::ClientMessage; -use crate::{ - hyper_servers::error::TransportServerResult, - mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, - mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::mcp_handler::McpServerHandler, -}; -use axum::{ - extract::State, - response::{ - sse::{Event, KeepAlive}, - IntoResponse, Sse, - }, - routing::get, - Extension, Router, -}; -use futures::stream::{self}; -use rust_mcp_transport::{SessionId, SseTransport}; -use std::{convert::Infallible, sync::Arc, time::Duration}; -use tokio::io::{duplex, AsyncBufReadExt, BufReader}; -use tokio_stream::StreamExt; - -const DUPLEX_BUFFER_SIZE: usize = 8192; +use crate::hyper_servers::error::TransportServerResult; +use crate::mcp_http::{McpAppState, McpHttpHandler}; +use axum::{extract::State, response::IntoResponse, routing::get, Extension, Router}; +use std::sync::Arc; #[derive(Clone)] pub struct SseMessageEndpoint(pub String); -/// Creates an initial SSE event that returns the messages endpoint -/// -/// Constructs an SSE event containing the messages endpoint URL with the session ID. -/// -/// # Arguments -/// * `session_id` - The session identifier for the client -/// -/// # Returns -/// * `Result` - The constructed SSE event, infallible -fn initial_event(endpoint: &str) -> Result { - Ok(Event::default().event("endpoint").data(endpoint)) -} - /// Configures the SSE routes for the application /// /// Sets up the Axum router with a single GET route for the specified SSE endpoint. @@ -50,11 +16,7 @@ fn initial_event(endpoint: &str) -> Result { /// /// # Returns /// * `Router>` - An Axum router configured with the SSE route -pub fn routes( - state: Arc, - sse_endpoint: &str, - sse_message_endpoint: &str, -) -> Router> { +pub fn routes(sse_endpoint: &str, sse_message_endpoint: &str) -> Router> { let sse_message_endpoint = SseMessageEndpoint(sse_message_endpoint.to_string()); Router::new().route( sse_endpoint, @@ -77,91 +39,9 @@ pub async fn handle_sse( State(state): State>, ) -> TransportServerResult { let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; - - let session_id: SessionId = state.id_generator.generate(); - - let messages_endpoint = - SseTransport::::message_endpoint(&sse_message_endpoint, &session_id); - - // readable stream of string to be used in transport - // writing string to read_tx will be received as messages inside the transport and messages will be processed - let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); - - // writable stream to deliver message to the client - let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - - // create a transport for sending/receiving messages - let Ok(transport) = SseTransport::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) else { - return Err(TransportServerError::TransportError( - "Failed to create SSE transport".to_string(), - )); - }; - - let h: Arc = state.handler.clone(); - // create a new server instance with unique session_id and - let server: Arc = server_runtime::create_server_instance( - Arc::clone(&state.server_details), - h, - session_id.to_owned(), - ); - - state - .session_store - .set(session_id.to_owned(), server.clone()) - .await; - - tracing::info!("A new client joined : {}", session_id.to_owned()); - - // Start the server - tokio::spawn(async move { - match server - .start_stream( - Arc::new(transport), - DEFAULT_STREAM_ID, - state.ping_interval, - None, - ) - .await - { - Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), - Err(err) => tracing::info!( - "server {} exited with error : {}", - session_id.to_owned(), - err - ), - }; - - state.session_store.delete(&session_id).await; - }); - - // Initial SSE message to inform the client about the server's endpoint - let initial_event = stream::once(async move { initial_event(&messages_endpoint) }); - - // Construct SSE stream - let reader = BufReader::new(write_rx); - - let message_stream = stream::unfold(reader, |mut reader| async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some((Ok(Event::default().data(trimmed_line)), reader)) - } - Err(_) => None, // Err(e) => Some((Err(e), reader)), - } - }); - - let stream = initial_event.chain(message_stream); - let sse_stream = - Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); - - // Return SSE response with keep-alive - Ok(sse_stream) + let generic_response = + McpHttpHandler::handle_sse_connection(state.clone(), Some(&sse_message_endpoint)).await?; + let (parts, body) = generic_response.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index bfce062..881d4b3 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,7 +1,12 @@ use crate::{ error::SdkResult, id_generator::{FastIdGenerator, UuidGenerator}, - mcp_http::{InMemorySessionStore, McpAppState}, + mcp_http::{ + utils::{ + DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT, + }, + InMemorySessionStore, McpAppState, + }, mcp_server::hyper_runtime::HyperRuntime, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; @@ -27,12 +32,6 @@ use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5; -// Default Server-Sent Events (SSE) endpoint path -const DEFAULT_SSE_ENDPOINT: &str = "/sse"; -// Default MCP Messages endpoint path -const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; -// Default Streamable HTTP endpoint path -const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp"; /// Configuration struct for the Hyper server /// Used to configure the HyperServer instance. diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 6a02f8d..7dd944d 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1,6 +1,11 @@ +#[cfg(any(feature = "sse"))] +use super::utils::handle_sse_connection; use crate::mcp_http::utils::{ - accepts_event_stream, error_response, validate_mcp_protocol_version_header, + accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header, }; +use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID; +use crate::mcp_server::error::TransportServerError; +use crate::schema::schema_utils::SdkError; use crate::{ error::McpSdkError, mcp_http::{ @@ -12,10 +17,11 @@ use crate::{ McpAppState, }, mcp_server::error::TransportServerResult, - schema::schema_utils::SdkError, utils::valid_initialize_method, }; +use bytes::Bytes; use http::{self, HeaderMap, Method, StatusCode, Uri}; +use http_body_util::{BodyExt, Full}; use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use std::sync::Arc; @@ -57,6 +63,46 @@ impl McpHttpHandler { } impl McpHttpHandler { + #[cfg(any(feature = "sse"))] + pub async fn handle_sse_connection( + state: Arc, + sse_message_endpoint: Option<&str>, + ) -> TransportServerResult> { + handle_sse_connection(state, sse_message_endpoint).await + } + + pub async fn handle_sse_message( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let session_id = + query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?; + + // transmit to the readable stream, that transport is reading from + let transmit = state.session_store.get(&session_id).await.ok_or( + TransportServerError::SessionIdInvalid(session_id.to_string()), + )?; + + let transmit = transmit.lock().await; + let message = *request.body(); + transmit + .consume_payload_string(DEFAULT_STREAM_ID, message) + .await + .map_err(|err| { + tracing::trace!("{}", err); + TransportServerError::StreamIoError(err.to_string()) + })?; + + let body = Full::new(Bytes::new()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + + http::Response::builder() + .status(StatusCode::ACCEPTED) + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + pub async fn handle_streamable_http( request: http::Request<&str>, state: Arc, @@ -64,7 +110,7 @@ impl McpHttpHandler { // Enforces DNS rebinding protection if required by state. // If protection fails, respond with HTTP 403 Forbidden. if state.needs_dns_protection() { - if let Err(error) = protect_dns_rebinding(&request.headers(), state.clone()).await { + if let Err(error) = protect_dns_rebinding(request.headers(), state.clone()).await { return error_response(StatusCode::FORBIDDEN, error); } } diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index c4bbc0b..9fb7758 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -24,10 +24,32 @@ use std::sync::Arc; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio_stream::StreamExt; +// Default Server-Sent Events (SSE) endpoint path +pub(crate) const DEFAULT_SSE_ENDPOINT: &str = "/sse"; +// Default MCP Messages endpoint path +pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; +// Default Streamable HTTP endpoint path +pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp"; const DUPLEX_BUFFER_SIZE: usize = 8192; pub type GenericBody = BoxBody; +/// Creates an initial SSE event that returns the messages endpoint +/// +/// Constructs an SSE event containing the messages endpoint URL with the session ID. +/// +/// # Arguments +/// * `session_id` - The session identifier for the client +/// +/// # Returns +/// * `Result` - The constructed SSE event, infallible +fn initial_sse_event(endpoint: &str) -> Result { + Ok(SseEvent::default() + .with_event("endpoint") + .with_data(endpoint.to_string()) + .as_bytes()) +} + async fn create_sse_stream( runtime: Arc, session_id: SessionId, @@ -603,3 +625,122 @@ pub(crate) async fn protect_dns_rebinding( Ok(()) } + +pub fn query_param<'a>(request: &'a http::Request<&str>, key: &str) -> Option { + request.uri().query().and_then(|query| { + for pair in query.split('&') { + let mut split = pair.splitn(2, '='); + let k = split.next()?; + let v = split.next().unwrap_or(""); + if k == key { + return Some(v.to_string()); + } + } + None + }) +} + +#[cfg(any(feature = "sse"))] +pub(crate) async fn handle_sse_connection( + state: Arc, + sse_message_endpoint: Option<&str>, +) -> TransportServerResult> { + let session_id: SessionId = state.id_generator.generate(); + + let sse_message_endpoint = sse_message_endpoint.unwrap_or(DEFAULT_MESSAGES_ENDPOINT); + let messages_endpoint = + SseTransport::::message_endpoint(sse_message_endpoint, &session_id); + + // readable stream of string to be used in transport + // writing string to read_tx will be received as messages inside the transport and messages will be processed + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + // / create a transport for sending/receiving messages + let Ok(transport) = SseTransport::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) else { + return Err(TransportServerError::TransportError( + "Failed to create SSE transport".to_string(), + )); + }; + + let h: Arc = state.handler.clone(); + // create a new server instance with unique session_id and + let server: Arc = server_runtime::create_server_instance( + Arc::clone(&state.server_details), + h, + session_id.to_owned(), + ); + + state + .session_store + .set(session_id.to_owned(), server.clone()) + .await; + + tracing::info!("A new client joined : {}", session_id.to_owned()); + + // Start the server + tokio::spawn(async move { + match server + .start_stream( + Arc::new(transport), + DEFAULT_STREAM_ID, + state.ping_interval, + None, + ) + .await + { + Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), + Err(err) => tracing::info!( + "server {} exited with error : {}", + session_id.to_owned(), + err + ), + }; + + state.session_store.delete(&session_id).await; + }); + + // Initial SSE message to inform the client about the server's endpoint + let initial_sse_event = stream::once(async move { initial_sse_event(&messages_endpoint) }); + + // Construct SSE stream + let reader = BufReader::new(write_rx); + + let message_stream = stream::unfold(reader, |mut reader| async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + Some(( + Ok(SseEvent::default().with_data(trimmed_line).as_bytes()), + reader, + )) + } + Err(_) => None, // Err(e) => Some((Err(e), reader)), + } + }); + + let stream = initial_sse_event.chain(message_stream); + + // create a stream body + let streaming_body: GenericBody = + http_body_util::BodyExt::boxed(StreamBody::new(stream.map(|res| res.map(Frame::data)))); + + let response = http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "text/event-stream") + .header(CONNECTION, "keep-alive") + .body(streaming_body) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + Ok(response) +} From fca47584eaa4876265d269fe13ec65568eb898a3 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 12 Oct 2025 16:21:32 -0300 Subject: [PATCH 14/14] chore: update feature flags --- crates/rust-mcp-sdk/Cargo.toml | 14 ++--- .../rust-mcp-sdk/src/hyper_servers/routes.rs | 4 +- .../src/mcp_http/mcp_http_handler.rs | 57 ++++++++++++++++++- .../src/mcp_http/mcp_http_utils.rs | 17 +++++- 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 85371fe..be70e07 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -29,13 +29,13 @@ tokio-stream = { workspace = true, optional = true } axum-server = { version = "0.7", features = [], optional = true } tracing.workspace = true base64.workspace = true +bytes.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } -http = "1.3.1" -http-body-util = "0.1.3" -http-body = "1.0.1" -bytes.workspace = true +http = { version ="1.3", optional = true } +http-body-util = { version ="0.1", optional = true } +http-body = { version ="1.0", optional = true } [dev-dependencies] wiremock = "0.5" @@ -65,13 +65,13 @@ default = [ "2025_06_18", ] # All features enabled by default -sse = ["rust-mcp-transport/sse"] -streamable-http = ["rust-mcp-transport/streamable-http"] +sse = ["rust-mcp-transport/sse","http","http-body","http-body-util"] +streamable-http = ["rust-mcp-transport/streamable-http","http","http-body","http-body-util"] stdio = ["rust-mcp-transport/stdio"] server = [] # Server feature client = [] # Client feature -hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream"] +hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream","http","http-body","http-body-util"] ssl = ["axum-server/tls-rustls"] tls-no-provider = ["axum-server/tls-rustls-no-provider"] macros = ["rust-mcp-macros/sdk"] diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index 7dc33f2..4ae274b 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -1,6 +1,6 @@ pub mod fallback_routes; pub mod messages_routes; -#[cfg(any(feature = "sse"))] +#[cfg(feature = "sse")] pub mod sse_routes; pub mod streamable_http_routes; @@ -28,7 +28,7 @@ pub fn app_routes(state: Arc, server_options: &HyperServerOptions) )) .merge({ let mut r = Router::new(); - #[cfg(any(feature = "sse"))] + #[cfg(feature = "sse")] if server_options.sse_support { r = r .merge(sse_routes::routes( diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 7dd944d..fb830ae 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1,4 +1,4 @@ -#[cfg(any(feature = "sse"))] +#[cfg(feature = "sse")] use super::utils::handle_sse_connection; use crate::mcp_http::utils::{ accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header, @@ -63,7 +63,19 @@ impl McpHttpHandler { } impl McpHttpHandler { - #[cfg(any(feature = "sse"))] + /// Handles an MCP connection using the SSE (Server-Sent Events) transport. + /// + /// This function serves as the entry point for initializing and managing a client connection + /// over SSE when the `sse` feature is enabled. + /// + /// # Arguments + /// * `state` - Shared application state required to manage the MCP session. + /// * `sse_message_endpoint` - Optional message endpoint to override the default SSE route (default: `/messages` ). + /// + /// + /// # Features + /// This function is only available when the `sse` feature is enabled. + #[cfg(feature = "sse")] pub async fn handle_sse_connection( state: Arc, sse_message_endpoint: Option<&str>, @@ -71,6 +83,26 @@ impl McpHttpHandler { handle_sse_connection(state, sse_message_endpoint).await } + /// Handles incoming MCP messages from the client after an SSE connection is established. + /// + /// This function processes a message sent by the client as part of an active SSE session. It: + /// - Extracts the `sessionId` from the request query parameters. + /// - Locates the corresponding session's transmit channel. + /// - Forwards the incoming message payload to the MCP transport stream for consumption. + /// # Arguments + /// * `request` - The HTTP request containing the message body and query parameters (including `sessionId`). + /// * `state` - Shared application state, including access to the session store. + /// + /// # Returns + /// * `TransportServerResult>`: + /// - Returns a `202 Accepted` HTTP response if the message is successfully forwarded. + /// - Returns an error if the session ID is missing, invalid, or if any I/O issues occur while processing the message. + /// + /// # Errors + /// - `SessionIdMissing`: if the `sessionId` query parameter is not present. + /// - `SessionIdInvalid`: if the session ID does not map to a valid session in the session store. + /// - `StreamIoError`: if an error occurs while writing to the stream. + /// - `HttpError`: if constructing the HTTP response fails. pub async fn handle_sse_message( request: http::Request<&str>, state: Arc, @@ -103,6 +135,27 @@ impl McpHttpHandler { .map_err(|err| TransportServerError::HttpError(err.to_string())) } + /// Handles incoming MCP messages over the StreamableHTTP transport. + /// + /// It supports `GET`, `POST`, and `DELETE` methods for handling streaming operations, and performs optional + /// DNS rebinding protection if it is configured. + /// + /// # Arguments + /// * `request` - The HTTP request from the client, including method, headers, and optional body. + /// * `state` - Shared application state, including configuration and session management. + /// + /// # Behavior + /// - If DNS rebinding protection is enabled via the app state, the function checks the request headers. + /// If dns protection fails, a `403 Forbidden` response is returned. + /// - Dispatches the request to method-specific handlers based on the HTTP method: + /// - `GET` → `handle_http_get` + /// - `POST` → `handle_http_post` + /// - `DELETE` → `handle_http_delete` + /// - Returns `405 Method Not Allowed` for unsupported methods. + /// + /// # Returns + /// * A `TransportServerResult` wrapping an HTTP response indicating success or failure of the operation. + /// pub async fn handle_streamable_http( request: http::Request<&str>, state: Arc, diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index 9fb7758..06443e9 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -626,7 +626,20 @@ pub(crate) async fn protect_dns_rebinding( Ok(()) } -pub fn query_param<'a>(request: &'a http::Request<&str>, key: &str) -> Option { +/// Extracts the value of a query parameter from an HTTP request by key. +/// +/// This function parses the query string from the request URI and searches +/// for the specified key. If found, it returns the corresponding value as a `String`. +/// +/// # Arguments +/// * `request` - The HTTP request containing the URI with the query string. +/// * `key` - The name of the query parameter to retrieve. +/// +/// # Returns +/// * `Some(String)` containing the value of the query parameter if found. +/// * `None` if the query string is missing or the key is not present. +/// +pub fn query_param(request: &http::Request<&str>, key: &str) -> Option { request.uri().query().and_then(|query| { for pair in query.split('&') { let mut split = pair.splitn(2, '='); @@ -640,7 +653,7 @@ pub fn query_param<'a>(request: &'a http::Request<&str>, key: &str) -> Option, sse_message_endpoint: Option<&str>,