diff --git a/.release-manifest.json b/.release-manifest.json index e8ad288..db381e1 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,15 @@ { - "crates/rust-mcp-sdk": "0.5.1", - "crates/rust-mcp-macros": "0.5.1", - "crates/rust-mcp-transport": "0.4.1", - "examples/hello-world-mcp-server": "0.1.25", - "examples/hello-world-mcp-server-core": "0.1.16", - "examples/simple-mcp-client": "0.1.25", - "examples/simple-mcp-client-core": "0.1.25", - "examples/hello-world-server-core-streamable-http": "0.1.16", - "examples/hello-world-server-streamable-http": "0.1.25", - "examples/simple-mcp-client-core-sse": "0.1.16", - "examples/simple-mcp-client-sse": "0.1.16" + "crates/rust-mcp-sdk": "0.7.0", + "crates/rust-mcp-macros": "0.5.2", + "crates/rust-mcp-transport": "0.6.0", + "examples/hello-world-mcp-server-stdio": "0.1.29", + "examples/hello-world-mcp-server-stdio-core": "0.1.20", + "examples/simple-mcp-client-stdio": "0.1.29", + "examples/simple-mcp-client-stdio-core": "0.1.29", + "examples/hello-world-server-streamable-http-core": "0.1.20", + "examples/hello-world-server-streamable-http": "0.1.32", + "examples/simple-mcp-client-sse-core": "0.1.20", + "examples/simple-mcp-client-sse": "0.1.23", + "examples/simple-mcp-client-streamable-http": "0.1.1", + "examples/simple-mcp-client-streamable-http-core": "0.1.1" } diff --git a/Cargo.lock b/Cargo.lock index df081df..0acb30d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", @@ -84,9 +84,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.13.3" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c953fe1ba023e6b7730c0d4b031d06f267f23a46167dcbd40316644b10a17ba" +checksum = "94b8ff6c09cd57b16da53641caa860168b88c172a5ee163b0288d3d6eea12786" dependencies = [ "aws-lc-sys", "zeroize", @@ -94,9 +94,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" +checksum = "0e44d16778acaf6a9ec9899b92cebd65580b83f685446bf2e1f5d3d732f99dcd" dependencies = [ "bindgen", "cc", @@ -118,7 +118,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "itoa", "matchit", @@ -170,7 +170,7 @@ dependencies = [ "fs-err", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "pin-project-lite", "rustls", @@ -216,32 +216,29 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bindgen" -version = "0.69.5" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ "bitflags", "cexpr", "clang-sys", "itertools", - "lazy_static", - "lazycell", "log", "prettyplease", "proc-macro2", "quote", "regex", - "rustc-hash 1.1.0", + "rustc-hash", "shlex", "syn", - "which", ] [[package]] name = "bitflags" -version = "2.9.1" +version = "2.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" [[package]] name = "bumpalo" @@ -257,10 +254,11 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.32" +version = "1.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2352e5597e9c544d5e6d9c95190d5d27738ade584fa8db0a16e130e5c2b5296e" +checksum = "65193589c6404eb80b450d618eaf9a2cafaaafd57ecce47370519ef674a7bd44" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -277,9 +275,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "cfg_aliases" @@ -381,9 +379,9 @@ checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" [[package]] name = "deranged" -version = "0.4.0" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" dependencies = [ "powerfmt", ] @@ -426,16 +424,6 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" -[[package]] -name = "errno" -version = "0.3.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" -dependencies = [ - "libc", - "windows-sys 0.60.2", -] - [[package]] name = "event-listener" version = "2.5.3" @@ -451,6 +439,12 @@ dependencies = [ "instant", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fd99930f64d146689264c637b5af2f0233a933bef0d8570e2526bf9e083192d" + [[package]] name = "fnv" version = "1.0.7" @@ -459,18 +453,18 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] [[package]] name = "fs-err" -version = "3.1.1" +version = "3.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d7be93788013f265201256d58f04936a8079ad5dc898743aa20525f503b683" +checksum = "44f150ffc8782f35521cec2b23727707cb4045706ba3c854e86bef66b3a8cdbd" dependencies = [ "autocfg", "tokio", @@ -626,7 +620,7 @@ dependencies = [ "js-sys", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasi 0.14.7+wasi-0.2.4", "wasm-bindgen", ] @@ -687,8 +681,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] -name = "hello-world-mcp-server" -version = "0.1.25" +name = "hello-world-mcp-server-stdio" +version = "0.1.29" dependencies = [ "async-trait", "futures", @@ -701,8 +695,8 @@ dependencies = [ ] [[package]] -name = "hello-world-mcp-server-core" -version = "0.1.16" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.20" dependencies = [ "async-trait", "futures", @@ -713,8 +707,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-core-streamable-http" -version = "0.1.16" +name = "hello-world-server-streamable-http" +version = "0.1.32" dependencies = [ "async-trait", "futures", @@ -727,8 +721,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-streamable-http" -version = "0.1.25" +name = "hello-world-server-streamable-http-core" +version = "0.1.20" dependencies = [ "async-trait", "futures", @@ -746,15 +740,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" -[[package]] -name = "home" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" -dependencies = [ - "windows-sys 0.59.0", -] - [[package]] name = "http" version = "0.2.12" @@ -870,13 +855,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", @@ -884,6 +870,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -896,7 +883,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ "http 1.3.1", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "rustls", "rustls-pki-types", @@ -908,9 +895,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ "base64 0.22.1", "bytes", @@ -919,7 +906,7 @@ dependencies = [ "futures-util", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "ipnet", "libc", "percent-encoding", @@ -1018,9 +1005,9 @@ dependencies = [ [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -1039,9 +1026,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.10.0" +version = "2.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" +checksum = "92119844f513ffa41556430369ab02c295a3578af21cf945caa3e9e0c2481ac3" dependencies = [ "equivalent", "hashbrown", @@ -1064,9 +1051,9 @@ dependencies = [ [[package]] name = "io-uring" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" dependencies = [ "bitflags", "cfg-if", @@ -1091,9 +1078,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -1106,9 +1093,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ "getrandom 0.3.3", "libc", @@ -1116,9 +1103,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "852f13bec5eba4ba9afbeb93fd7c13fe56147f055939ae21c43a29a0ecb2702e" dependencies = [ "once_cell", "wasm-bindgen", @@ -1130,12 +1117,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "libc" version = "0.2.175" @@ -1152,12 +1133,6 @@ dependencies = [ "windows-targets 0.53.3", ] -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - [[package]] name = "litemap" version = "0.8.0" @@ -1182,9 +1157,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" [[package]] name = "lru-slab" @@ -1194,11 +1169,11 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -1267,12 +1242,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -1306,12 +1280,6 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking" version = "2.2.1" @@ -1343,9 +1311,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pin-project-lite" @@ -1361,9 +1329,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "potential_utf" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" dependencies = [ "zerovec", ] @@ -1385,9 +1353,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.36" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", "syn", @@ -1395,9 +1363,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.97" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -1420,19 +1388,19 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ "bytes", "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", - "socket2 0.5.10", - "thiserror 2.0.14", + "socket2 0.6.0", + "thiserror 2.0.16", "tokio", "tracing", "web-time", @@ -1440,20 +1408,20 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.12" +version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes", "getrandom 0.3.3", "lru-slab", "rand 0.9.2", "ring", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.14", + "thiserror 2.0.16", "tinyvec", "tracing", "web-time", @@ -1461,16 +1429,16 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcebb1209ee276352ef14ff8732e24cc2b02bbac986cd74a4c81bcb2f9881970" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.0", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -1569,47 +1537,32 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - -[[package]] -name = "regex-syntax" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" [[package]] name = "reqwest" @@ -1626,7 +1579,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-rustls", "hyper-util", "js-sys", @@ -1677,7 +1630,7 @@ dependencies = [ [[package]] name = "rust-mcp-macros" -version = "0.5.1" +version = "0.5.2" dependencies = [ "proc-macro2", "quote", @@ -1689,9 +1642,9 @@ dependencies = [ [[package]] name = "rust-mcp-schema" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0e71aee61257cd3d4a78fdc10c92c29e7a55c4f767119ffdafd837bb5e5cb9a" +checksum = "0bb65fd293dbbfabaacba1512b3948cdd9bf31ad1f2c0fed4962052b590c5c44" dependencies = [ "serde", "serde_json", @@ -1699,30 +1652,32 @@ dependencies = [ [[package]] name = "rust-mcp-sdk" -version = "0.5.1" +version = "0.7.0" dependencies = [ "async-trait", "axum", "axum-server", + "base64 0.22.1", "futures", - "hyper 1.6.0", + "hyper 1.7.0", "reqwest", "rust-mcp-macros", "rust-mcp-schema", "rust-mcp-transport", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.16", "tokio", "tokio-stream", "tracing", "tracing-subscriber", "uuid", + "wiremock", ] [[package]] name = "rust-mcp-transport" -version = "0.4.1" +version = "0.6.0" dependencies = [ "async-trait", "bytes", @@ -1731,7 +1686,7 @@ dependencies = [ "rust-mcp-schema", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.16", "tokio", "tokio-stream", "tracing", @@ -1744,31 +1699,12 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.59.0", -] - [[package]] name = "rustls" version = "0.23.31" @@ -1805,9 +1741,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.4" +version = "0.103.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" +checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" dependencies = [ "aws-lc-rs", "ring", @@ -1835,18 +1771,28 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.225" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6c24dee235d0da097043389623fb913daddf92c76e9f5a1db88607a0bcbd1d" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "659356f9a0cb1e529b24c01e43ad2bdf520ec4ceaf83047b83ddcc2251f96383" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "0ea936adf78b1f766949a4977b91d2f5595825bd6ec079aa9543ad2685fc4516" dependencies = [ "proc-macro2", "quote", @@ -1855,24 +1801,26 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.142" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "030fedb782600dcbd6f02d479bf0d817ac3bb40d644745b769d6a96bc3afc5a7" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ "itoa", "memchr", "ryu", "serde", + "serde_core", ] [[package]] name = "serde_path_to_error" -version = "0.1.17" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" dependencies = [ "itoa", "serde", + "serde_core", ] [[package]] @@ -1923,8 +1871,24 @@ dependencies = [ ] [[package]] -name = "simple-mcp-client" -version = "0.1.25" +name = "simple-mcp-client-sse" +version = "0.1.23" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "simple-mcp-client-sse-core" +version = "0.1.20" dependencies = [ "async-trait", "colored", @@ -1932,13 +1896,15 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-core" -version = "0.1.25" +name = "simple-mcp-client-stdio" +version = "0.1.29" dependencies = [ "async-trait", "colored", @@ -1946,13 +1912,27 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.16", "tokio", ] [[package]] -name = "simple-mcp-client-core-sse" -version = "0.1.16" +name = "simple-mcp-client-stdio-core" +version = "0.1.29" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + +[[package]] +name = "simple-mcp-client-streamable-http" +version = "0.1.1" dependencies = [ "async-trait", "colored", @@ -1960,15 +1940,15 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.16", "tokio", "tracing", "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-sse" -version = "0.1.16" +name = "simple-mcp-client-streamable-http-core" +version = "0.1.1" dependencies = [ "async-trait", "colored", @@ -1976,7 +1956,7 @@ dependencies = [ "rust-mcp-sdk", "serde", "serde_json", - "thiserror 2.0.14", + "thiserror 2.0.16", "tokio", "tracing", "tracing-subscriber", @@ -2028,9 +2008,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.104" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -2068,11 +2048,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.14" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b0949c3a6c842cbde3f1686d6eea5a010516deb7085f79db747562d4102f41e" +checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" dependencies = [ - "thiserror-impl 2.0.14", + "thiserror-impl 2.0.16", ] [[package]] @@ -2088,9 +2068,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.14" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc5b44b4ab9c2fdd0e0512e6bece8388e214c0749f5862b114cc5b7a25daf227" +checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" dependencies = [ "proc-macro2", "quote", @@ -2108,12 +2088,11 @@ dependencies = [ [[package]] name = "time" -version = "0.3.41" +version = "0.3.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "83bde6f1ec10e72d583d91623c939f623002284ef622b87de38cfd546cbf2031" dependencies = [ "deranged", - "itoa", "num-conv", "powerfmt", "serde", @@ -2123,15 +2102,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.22" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", @@ -2149,9 +2128,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ "tinyvec_macros", ] @@ -2195,9 +2174,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +checksum = "05f63835928ca123f1bef57abbcd23bb2ba0ac9ae1235f1e65bda0d06e7786bd" dependencies = [ "rustls", "tokio", @@ -2319,14 +2298,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", @@ -2349,9 +2328,9 @@ checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" [[package]] name = "untrusted" @@ -2361,9 +2340,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.4" +version = "2.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" dependencies = [ "form_urlencoded", "idna", @@ -2379,9 +2358,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.18.0" +version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f33196643e165781c20a5ead5582283a7dacbb87855d867fbc2df3f81eddc1be" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ "getrandom 0.3.3", "js-sys", @@ -2429,30 +2408,40 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.2+wasi-0.2.4" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = "ab10a69fbd0a177f5f649ad4d8d3305499c42bab9aef2f7ff592d0ec8f833819" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", + "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +checksum = "0bb702423545a6007bbc368fde243ba47ca275e549c8a28617f56f6ba53b1d1c" dependencies = [ "bumpalo", "log", @@ -2464,9 +2453,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.50" +version = "0.4.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" +checksum = "a0b221ff421256839509adbb55998214a70d829d3a28c69b4a6672e9d2a42f67" dependencies = [ "cfg-if", "js-sys", @@ -2477,9 +2466,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "fc65f4f411d91494355917b605e1480033152658d71f722a90647f56a70c88a0" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2487,9 +2476,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "ffc003a991398a8ee604a401e194b6b3a39677b3173d6e74495eb51b82e99a32" dependencies = [ "proc-macro2", "quote", @@ -2500,9 +2489,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = "293c37f4efa430ca14db3721dfbe48d8c33308096bd44d80ebaa775ab71ba1cf" dependencies = [ "unicode-ident", ] @@ -2522,9 +2511,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "fbe734895e869dc429d78c4b433f8d17d95f8d05317440b4fad5ab2d33e596dc" dependencies = [ "js-sys", "wasm-bindgen", @@ -2549,40 +2538,6 @@ dependencies = [ "rustls-pki-types", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-link" version = "0.1.3" @@ -2768,13 +2723,10 @@ dependencies = [ ] [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" -dependencies = [ - "bitflags", -] +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "writeable" @@ -2808,18 +2760,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index a85b5a7..edb7e28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,21 +4,24 @@ members = [ "crates/rust-mcp-macros", "crates/rust-mcp-sdk", "crates/rust-mcp-transport", - "examples/simple-mcp-client", - "examples/simple-mcp-client-core", - "examples/hello-world-mcp-server", - "examples/hello-world-mcp-server-core", + "examples/simple-mcp-client-stdio", + "examples/simple-mcp-client-stdio-core", + "examples/hello-world-mcp-server-stdio", + "examples/hello-world-mcp-server-stdio-core", "examples/hello-world-server-streamable-http", - "examples/hello-world-server-core-streamable-http", + "examples/hello-world-server-streamable-http-core", "examples/simple-mcp-client-sse", - "examples/simple-mcp-client-core-sse", + "examples/simple-mcp-client-sse-core", + "examples/simple-mcp-client-streamable-http", + "examples/simple-mcp-client-streamable-http-core", + ] [workspace.dependencies] # Workspace member crates -rust-mcp-transport = { version = "0.4.1", path = "crates/rust-mcp-transport", default-features = false } +rust-mcp-transport = { version = "0.6.0", path = "crates/rust-mcp-transport", default-features = false } rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false } -rust-mcp-macros = { version = "0.5.1", path = "crates/rust-mcp-macros", default-features = false } +rust-mcp-macros = { version = "0.5.2", path = "crates/rust-mcp-macros", default-features = false } # External crates rust-mcp-schema = { version = "0.7", default-features = false } @@ -39,7 +42,7 @@ tracing-subscriber = { version = "0.3", features = [ "std", "fmt", ] } - +base64 = "0.22" axum = "0.8" rustls = "0.23" tokio-rustls = "0.26" diff --git a/README.md b/README.md index ef5b4ed..2c70c3e 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -32,15 +32,14 @@ This project supports following transports: πŸš€ The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - βœ… Streamable HTTP Support for MCP Servers - βœ… DNS Rebinding Protection - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients -- ⬜ Resumability -- ⬜ Authentication / Oauth +- βœ… Streamable HTTP Support for MCP Clients +- βœ… Resumability +- ⬜ Oauth Authentication **⚠️** Project is currently under development and should be used at your own risk. @@ -49,7 +48,9 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) +- [Macros](#macros) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) - [Security Considerations](#security-considerations) @@ -110,7 +111,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -153,6 +154,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); @@ -180,7 +182,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +193,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -205,7 +207,7 @@ impl ServerHandler for MyServerHandler { --- -πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -283,6 +285,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -294,8 +298,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +πŸ‘‰ see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -306,6 +384,116 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + + +## Macros +[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. + +> To use these macros, ensure the `macros` feature is enabled in your Cargo.toml. + +### mcp_tool +`mcp_tool` is a procedural macro attribute that helps generating rust_mcp_schema::Tool from a struct. + +Usage example: +```rust +#[mcp_tool( + name = "move_file", + title="Move File", + description = concat!("Move or rename files and directories. Can move files between directories ", +"and rename them in a single operation. If the destination exists, the ", +"operation will fail. Works across different directories and can be used ", +"for simple renaming within the same directory. ", +"Both source and destination must be within allowed directories."), + destructive_hint = false, + idempotent_hint = false, + open_world_hint = false, + read_only_hint = false +)] +#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] +pub struct MoveFileTool { + /// The source path of the file to move. + pub source: String, + /// The destination path to move the file to. + pub destination: String, +} + +// Now we can call `tool()` method on it to get a Tool instance +let rust_mcp_sdk::schema::Tool = MoveFileTool::tool(); + +``` + +πŸ’» For a real-world example, check out any of the tools available at: https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools + + +### tool_box +`tool_box` generates an enum from a provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. + +It accepts an array of tools and generates an enum where each tool becomes a variant of the enum. + +Generated enum has a `tools()` function that returns a `Vec` , and a `TryFrom` trait implementation that could be used to convert a ToolRequest into a Tool instance. + +Usage example: +```rust + // Accepts an array of tools and generates an enum named `FileSystemTools`, + // where each tool becomes a variant of the enum. + tool_box!(FileSystemTools, [ReadFileTool, MoveFileTool, SearchFilesTool]); + + // now in the app, we can use the FileSystemTools, like: + let all_tools: Vec = FileSystemTools::tools(); +``` + +πŸ’» To see a real-world example of that please see : +- `tool_box` macro usage: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs) +- using `tools()` in list tools request : [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L67) +- using `try_from` in call tool_request: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L100) + + + +### mcp_elicit +The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: + +- A `message()` method returning the elicitation message as a string. +- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. +- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. + +### Attributes + +- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. + +Usage example: +```rust +// A struct that could be used to send elicit request and get the input from the user +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} + +// send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance +let result: ElicitResult = server + .elicit_input(UserInfo::message(), UserInfo::requested_schema()) + .await?; + +// Create a UserInfo instance using data provided by the user on the client side +let user_info = UserInfo::from_content_map(result.content)?; + +``` + +πŸ’» For mre info please see : +- https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros ## Getting Started @@ -337,6 +525,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -344,9 +533,19 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -356,12 +555,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -373,17 +566,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -395,6 +577,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -416,9 +609,15 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. + +- `stdio`: Enables support for the `standard input/output (stdio)` transport. +- `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. + #### MCP Protocol Versions with Corresponding Features @@ -449,9 +648,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -464,7 +663,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` @@ -477,10 +676,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -509,7 +708,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK @@ -526,6 +725,11 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | + + + + diff --git a/crates/rust-mcp-macros/CHANGELOG.md b/crates/rust-mcp-macros/CHANGELOG.md index a7b5306..69b3059 100644 --- a/crates/rust-mcp-macros/CHANGELOG.md +++ b/crates/rust-mcp-macros/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.5.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-macros-v0.5.1...rust-mcp-macros-v0.5.2) (2025-09-19) + + +### πŸš€ Features + +* Add elicitation macros and add elicit_input() method ([#99](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/99)) ([3ab5fe7](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/3ab5fe73aaa10de2b5b23caee357ac15b37c845f)) + ## [0.5.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-macros-v0.5.0...rust-mcp-macros-v0.5.1) (2025-08-12) diff --git a/crates/rust-mcp-macros/Cargo.toml b/crates/rust-mcp-macros/Cargo.toml index 0dfdc56..9c2dd5a 100644 --- a/crates/rust-mcp-macros/Cargo.toml +++ b/crates/rust-mcp-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-macros" -version = "0.5.1" +version = "0.5.2" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "A procedural macro that derives the MCPToolSchema implementation for structs or enums, generating a tool_input_schema function used with rust_mcp_schema::Tool." diff --git a/crates/rust-mcp-macros/README.md b/crates/rust-mcp-macros/README.md index 92da2c3..fc463cd 100644 --- a/crates/rust-mcp-macros/README.md +++ b/crates/rust-mcp-macros/README.md @@ -1,5 +1,8 @@ # rust-mcp-macros. + +## mcp_tool Macro + A procedural macro, part of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) ecosystem, to generate `rust_mcp_schema::Tool` instance from a struct. The `mcp_tool` macro generates an implementation for the annotated struct that includes: @@ -80,11 +83,7 @@ fn main() { ``` ---- - Check out [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) , a high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) takes care of the rest! - ---- **Note**: The following attributes are available only in version `2025_03_26` and later of the MCP Schema, and their values will be used in the [annotations](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5557) attribute of the *[Tool struct](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5554-L5566). @@ -93,3 +92,106 @@ fn main() { - `idempotent_hint` - `open_world_hint` - `read_only_hint` + + + + + +## mcp_elicit Macro + +The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: + +- A `message()` method returning the elicitation message as a string. +- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. +- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. + +### Attributes + +- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. + +### Supported Field Types + +- `String`: Maps to `ElicitResultContentValue::String`. +- `bool`: Maps to `ElicitResultContentValue::Boolean`. +- `i32`: Maps to `ElicitResultContentValue::Integer` (with bounds checking). +- `i64`: Maps to `ElicitResultContentValue::Integer`. +- `enum` Only simple enums are supported. The enum must implement the FromStr trait. +- `Option`: Supported for any of the above types, mapping to `None` if the field is missing. + + +### Usage Example + +```rust +use rust_mcp_sdk::macros::{mcp_elicit, JsonSchema}; +use rust_mcp_sdk::schema::RpcError; +use std::str::FromStr; + +// Simple enum with FromStr trait implemented +#[derive(JsonSchema, Debug)] +pub enum Colors { + #[json_schema(title = "Green Color")] + Green, + #[json_schema(title = "Red Color")] + Red, +} +impl FromStr for Colors { + type Err = RpcError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "green" => Ok(Colors::Green), + "red" => Ok(Colors::Red), + _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), + } + } +} + +// A struct that could be used to send elicit request and get the input from the user +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + + /// Email address of the user + #[json_schema(title = "Email", format = "email")] + pub email: Option, + + /// The user's age in years + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} + + // .... + // ....... + // ........... + + // send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance + + let result: ElicitResult = server + .elicit_input(UserInfo::message(), UserInfo::requested_schema()) + .await?; + + // Create a UserInfo instance using data provided by the user on the client side + let user_info = UserInfo::from_content_map(result.content)?; + + +``` + +--- + + Check out [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk), a high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) takes care of the rest! + +--- diff --git a/crates/rust-mcp-macros/src/lib.rs b/crates/rust-mcp-macros/src/lib.rs index 35d6e55..473792c 100644 --- a/crates/rust-mcp-macros/src/lib.rs +++ b/crates/rust-mcp-macros/src/lib.rs @@ -6,7 +6,7 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Error, Expr, - ExprLit, Fields, Lit, Meta, Token, + ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments, Token, Type, }; use utils::{is_option, renamed_field, type_to_json_schema}; @@ -45,6 +45,8 @@ struct McpToolMacroAttributes { use syn::parse::ParseStream; +use crate::utils::{generate_enum_parse, is_enum}; + struct ExprList { exprs: Punctuated, } @@ -246,6 +248,66 @@ impl Parse for McpToolMacroAttributes { } } +struct McpElicitationAttributes { + message: Option, +} + +impl Parse for McpElicitationAttributes { + fn parse(attributes: syn::parse::ParseStream) -> syn::Result { + let mut instance = Self { message: None }; + let meta_list: Punctuated = Punctuated::parse_terminated(attributes)?; + for meta in meta_list { + if let Meta::NameValue(meta_name_value) = meta { + let ident = meta_name_value.path.get_ident().unwrap(); + let ident_str = ident.to_string(); + if ident_str.as_str() == "message" { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + Expr::Macro(expr_macro) => { + let mac = &expr_macro.mac; + if mac.path.is_ident("concat") { + let args: ExprList = syn::parse2(mac.tokens.clone())?; + let mut result = String::new(); + for expr in args.exprs { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = expr + { + result.push_str(&lit_str.value()); + } else { + return Err(Error::new_spanned( + expr, + "Only string literals are allowed inside concat!()", + )); + } + } + result + } else { + return Err(Error::new_spanned( + expr_macro, + "Only concat!(...) is supported here", + )); + } + } + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a string literal or concat!(...)", + )); + } + }; + instance.message = Some(value) + } + } + } + Ok(instance) + } +} + /// A procedural macro attribute to generate rust_mcp_schema::Tool related utility methods for a struct. /// /// The `mcp_tool` macro generates an implementation for the annotated struct that includes: @@ -387,7 +449,7 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { let output = quote! { impl #input_ident { - /// Returns the name of the tool as a string. + /// Returns the name of the tool as a String. pub fn tool_name() -> String { #tool_name.to_string() } @@ -404,7 +466,7 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { .iter() .filter_map(|item| item.as_str().map(String::from)) .collect(), - None => Vec::new(), // Default to an empty vector if "required" is missing or not an array + None => Vec::new(), }; let properties: Option< @@ -440,6 +502,303 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { TokenStream::from(output) } +#[proc_macro_attribute] +pub fn mcp_elicit(attributes: TokenStream, input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let input_ident = &input.ident; + + // Conditionally select the path + let base_crate = if cfg!(feature = "sdk") { + quote! { rust_mcp_sdk::schema } + } else { + quote! { rust_mcp_schema } + }; + + let macro_attributes = parse_macro_input!(attributes as McpElicitationAttributes); + let message = macro_attributes.message.unwrap_or_default(); + + // Generate field assignments for from_content_map() + let field_assignments = match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => { + let assignments = fields.named.iter().map(|field| { + let field_attrs = &field.attrs; + let field_ident = &field.ident; + let renamed_field = renamed_field(field_attrs); + let field_name = renamed_field.unwrap_or_else(|| field_ident.as_ref().unwrap().to_string()); + let field_type = &field.ty; + + let type_check = if is_option(field_type) { + // Extract inner type for Option + let inner_type = match field_type { + Type::Path(type_path) => { + let segment = type_path.path.segments.last().unwrap(); + if segment.ident == "Option" { + match &segment.arguments { + PathArguments::AngleBracketed(args) => { + match args.args.first().unwrap() { + GenericArgument::Type(ty) => ty, + _ => panic!("Expected type argument in Option"), + } + } + _ => panic!("Invalid Option type"), + } + } else { + panic!("Expected Option type"); + } + } + _ => panic!("Expected Option type"), + }; + // Determine the match arm based on the inner type at compile time + let (inner_type_ident, match_pattern, conversion) = match inner_type { + Type::Path(type_path) if type_path.path.is_ident("String") => ( + quote! { String }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { s.clone() } + ), + Type::Path(type_path) if type_path.path.is_ident("bool") => ( + quote! { bool }, + quote! { #base_crate::ElicitResultContentValue::Boolean(b) }, + quote! { *b } + ), + Type::Path(type_path) if type_path.path.is_ident("i32") => ( + quote! { i32 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { + (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( + "Invalid number for field '{}': value {} does not fit in i32", + #field_name, *i + )))? + } + ), + Type::Path(type_path) if type_path.path.is_ident("i64") => ( + quote! { i64 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { *i } + ), + _ if is_enum(inner_type, &input) => { + let enum_parse = generate_enum_parse(inner_type, &field_name, &base_crate); + ( + quote! { #inner_type }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { #enum_parse } + ) + } + _ => panic!("Unsupported inner type for Option field: {}", quote! { #inner_type }), + }; + let inner_type_str = quote! { stringify!(#inner_type_ident) }; + quote! { + let #field_ident: Option<#inner_type_ident> = match content.as_ref().and_then(|map| map.get(#field_name)) { + Some(value) => { + match value { + #match_pattern => Some(#conversion), + _ => { + return Err(#base_crate::RpcError::parse_error().with_message(format!( + "Type mismatch for field '{}': expected {}, found {}", + #field_name, #inner_type_str, + match value { + #base_crate::ElicitResultContentValue::Boolean(_) => "boolean", + #base_crate::ElicitResultContentValue::String(_) => "string", + #base_crate::ElicitResultContentValue::Integer(_) => "integer", + } + ))); + } + } + } + None => None, + }; + } + } else { + // Determine the match arm based on the field type at compile time + let (field_type_ident, match_pattern, conversion) = match field_type { + Type::Path(type_path) if type_path.path.is_ident("String") => ( + quote! { String }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { s.clone() } + ), + Type::Path(type_path) if type_path.path.is_ident("bool") => ( + quote! { bool }, + quote! { #base_crate::ElicitResultContentValue::Boolean(b) }, + quote! { *b } + ), + Type::Path(type_path) if type_path.path.is_ident("i32") => ( + quote! { i32 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { + (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( + "Invalid number for field '{}': value {} does not fit in i32", + #field_name, *i + )))? + } + ), + Type::Path(type_path) if type_path.path.is_ident("i64") => ( + quote! { i64 }, + quote! { #base_crate::ElicitResultContentValue::Integer(i) }, + quote! { *i } + ), + _ if is_enum(field_type, &input) => { + let enum_parse = generate_enum_parse(field_type, &field_name, &base_crate); + ( + quote! { #field_type }, + quote! { #base_crate::ElicitResultContentValue::String(s) }, + quote! { #enum_parse } + ) + } + _ => panic!("Unsupported field type: {}", quote! { #field_type }), + }; + let type_str = quote! { stringify!(#field_type_ident) }; + quote! { + let #field_ident: #field_type_ident = match content.as_ref().and_then(|map| map.get(#field_name)) { + Some(value) => { + match value { + #match_pattern => #conversion, + _ => { + return Err(#base_crate::RpcError::parse_error().with_message(format!( + "Type mismatch for field '{}': expected {}, found {}", + #field_name, #type_str, + match value { + #base_crate::ElicitResultContentValue::Boolean(_) => "boolean", + #base_crate::ElicitResultContentValue::String(_) => "string", + #base_crate::ElicitResultContentValue::Integer(_) => "integer", + } + ))); + } + } + } + None => { + return Err(#base_crate::RpcError::parse_error().with_message(format!( + "Missing required field: {}", + #field_name + ))); + } + }; + } + }; + + type_check + }); + + let field_idents = fields.named.iter().map(|field| &field.ident); + + quote! { + #(#assignments)* + + Ok(Self { + #(#field_idents,)* + }) + } + } + _ => panic!("mcp_elicit macro only supports structs with named fields"), + }, + _ => panic!("mcp_elicit macro only supports structs"), + }; + + let output = quote! { + impl #input_ident { + + /// Returns the elicitation message defined in the `#[mcp_elicit(message = "...")]` attribute. + /// + /// This message is used to prompt the user or system for input when eliciting data for the struct. + /// If no message is provided in the attribute, an empty string is returned. + /// + /// # Returns + /// A `String` containing the elicitation message. + pub fn message()->String{ + #message.to_string() + } + + /// This method returns a `ElicitRequestedSchema` by retrieves the + /// struct's JSON schema (via the `JsonSchema` derive) and converting int into + /// a `ElicitRequestedSchema`. It extracts the `required` fields and + /// `properties` from the schema, mapping them to a `HashMap` of `PrimitiveSchemaDefinition` objects. + /// + /// # Returns + /// An `ElicitRequestedSchema` representing the schema of the struct. + /// + /// # Panics + /// Panics if the schema's properties cannot be converted to `PrimitiveSchemaDefinition` or if the schema + /// is malformed. + pub fn requested_schema() -> #base_crate::ElicitRequestedSchema { + let json_schema = &#input_ident::json_schema(); + + let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) { + Some(arr) => arr + .iter() + .filter_map(|item| item.as_str().map(String::from)) + .collect(), + None => Vec::new(), + }; + + let properties: Option> = json_schema + .get("properties") + .and_then(|v| v.as_object()) // Safely extract "properties" as an object. + .map(|properties| { + properties + .iter() + .filter_map(|(key, value)| { + serde_json::to_value(value) + .ok() // If serialization fails, return None. + .and_then(|v| { + if let serde_json::Value::Object(obj) = v { + Some(obj) + } else { + None + } + }) + .map(|obj| (key.to_string(), #base_crate::PrimitiveSchemaDefinition::try_from(&obj))) + }) + .collect() + }); + + let properties = properties + .map(|map| { + map.into_iter() + .map(|(k, v)| v.map(|ok_v| (k, ok_v))) // flip Result inside tuple + .collect::, _>>() // collect only if all Ok + }) + .transpose() + .unwrap(); + + let properties = + properties.expect("Was not able to create a ElicitRequestedSchema"); + + let requested_schema = #base_crate::ElicitRequestedSchema::new(properties, required); + requested_schema + } + + /// Converts a map of field names and `ElicitResultContentValue` into an instance of the struct. + /// + /// This method parses the provided content map, matching field names to struct fields and converting + /// `ElicitResultContentValue` variants into the appropriate Rust types (e.g., `String`, `bool`, `i32`, + /// `i64`, or simple enums). It supports both required and optional fields (`Option`). + /// + /// # Parameters + /// - `content`: An optional `HashMap` mapping field names to `ElicitResultContentValue` values. + /// + /// # Returns + /// - `Ok(Self)` if the map is successfully parsed into the struct. + /// - `Err(RpcError)` if: + /// - A required field is missing. + /// - A value’s type does not match the expected field type. + /// - An integer value cannot be converted (e.g., `i64` to `i32` out of bounds). + /// - An enum value is invalid (e.g., string value does not match a enum variant name). + /// + /// # Errors + /// Returns `RpcError` with messages like: + /// - `"Missing required field: {}"` + /// - `"Type mismatch for field '{}': expected {}, found {}"` + /// - `"Invalid number for field '{}': value {} does not fit in i32"` + /// - `"Invalid enum value for field '{}': expected 'Yes' or 'No', found '{}'"`. + pub fn from_content_map(content: ::std::option::Option<::std::collections::HashMap<::std::string::String, #base_crate::ElicitResultContentValue>>) -> Result { + #field_assignments + } + } + #input + }; + + TokenStream::from(output) +} + /// Derives a JSON Schema representation for a struct. /// /// This procedural macro generates a `json_schema()` method for the annotated struct, returning a @@ -473,70 +832,222 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { /// # Dependencies /// Relies on `serde_json` for `Map` and `Value` types. /// -#[proc_macro_derive(JsonSchema)] +#[proc_macro_derive(JsonSchema, attributes(json_schema))] pub fn derive_json_schema(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); + let input = syn::parse_macro_input!(input as DeriveInput); let name = &input.ident; - let fields = match &input.data { + let schema_body = match &input.data { Data::Struct(data) => match &data.fields { - Fields::Named(fields) => &fields.named, - _ => panic!("JsonSchema derive macro only supports named fields"), + Fields::Named(fields) => { + let field_entries = fields.named.iter().map(|field| { + let field_attrs = &field.attrs; + let renamed_field = renamed_field(field_attrs); + let field_name = + renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + let field_type = &field.ty; + + let schema = type_to_json_schema(field_type, field_attrs); + quote! { + properties.insert( + #field_name.to_string(), + serde_json::Value::Object(#schema) + ); + } + }); + + let required_fields = fields.named.iter().filter_map(|field| { + let renamed_field = renamed_field(&field.attrs); + let field_name = + renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + + let field_type = &field.ty; + if !is_option(field_type) { + Some(quote! { + required.push(#field_name.to_string()); + }) + } else { + None + } + }); + + quote! { + let mut schema = serde_json::Map::new(); + let mut properties = serde_json::Map::new(); + let mut required = Vec::new(); + + #(#field_entries)* + + #(#required_fields)* + + schema.insert("type".to_string(), serde_json::Value::String("object".to_string())); + schema.insert("properties".to_string(), serde_json::Value::Object(properties)); + if !required.is_empty() { + schema.insert("required".to_string(), serde_json::Value::Array( + required.into_iter().map(serde_json::Value::String).collect() + )); + } + + schema + } + } + _ => panic!("JsonSchema derive macro only supports named fields for structs"), }, - _ => panic!("JsonSchema derive macro only supports structs"), - }; + Data::Enum(data) => { + let variant_schemas = data.variants.iter().map(|variant| { + let variant_attrs = &variant.attrs; + let variant_name = variant.ident.to_string(); + let renamed_variant = renamed_field(variant_attrs).unwrap_or(variant_name.clone()); - let field_entries = fields.iter().map(|field| { - let field_attrs = &field.attrs; - let renamed_field = renamed_field(field_attrs); - let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); - let field_type = &field.ty; + // Parse variant-level json_schema attributes + let mut title: Option = None; + let mut description: Option = None; + for attr in variant_attrs { + if attr.path().is_ident("json_schema") { + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("title") { + title = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("description") { + description = Some(meta.value()?.parse::()?.value()); + } + Ok(()) + }); + } + } - let schema = type_to_json_schema(field_type, field_attrs); - quote! { - properties.insert( - #field_name.to_string(), - serde_json::Value::Object(#schema) - ); - } - }); + let title_quote = title.as_ref().map(|t| { + quote! { map.insert("title".to_string(), serde_json::Value::String(#t.to_string())); } + }); + let description_quote = description.as_ref().map(|desc| { + quote! { map.insert("description".to_string(), serde_json::Value::String(#desc.to_string())); } + }); - let required_fields = fields.iter().filter_map(|field| { - let renamed_field = renamed_field(&field.attrs); - let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + match &variant.fields { + Fields::Unit => { + // Unit variant: use "enum" with the variant name + quote! { + { + let mut map = serde_json::Map::new(); + map.insert("enum".to_string(), serde_json::Value::Array(vec![ + serde_json::Value::String(#renamed_variant.to_string()) + ])); + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } + Fields::Unnamed(fields) => { + // Newtype or tuple variant + if fields.unnamed.len() == 1 { + // Newtype variant: use the inner type's schema + let field = &fields.unnamed[0]; + let field_type = &field.ty; + let field_attrs = &field.attrs; + let schema = type_to_json_schema(field_type, field_attrs); + quote! { + { + let mut map = #schema; + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } else { + // Tuple variant: array with items + let field_schemas = fields.unnamed.iter().map(|field| { + let field_type = &field.ty; + let field_attrs = &field.attrs; + let schema = type_to_json_schema(field_type, field_attrs); + quote! { serde_json::Value::Object(#schema) } + }); + quote! { + { + let mut map = serde_json::Map::new(); + map.insert("type".to_string(), serde_json::Value::String("array".to_string())); + map.insert("items".to_string(), serde_json::Value::Array(vec![#(#field_schemas),*])); + map.insert("additionalItems".to_string(), serde_json::Value::Bool(false)); + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } + } + Fields::Named(fields) => { + // Struct variant: object with properties and required fields + let field_entries = fields.named.iter().map(|field| { + let field_attrs = &field.attrs; + let renamed_field = renamed_field(field_attrs); + let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + let field_type = &field.ty; - let field_type = &field.ty; - if !is_option(field_type) { - Some(quote! { - required.push(#field_name.to_string()); - }) - } else { - None - } - }); + let schema = type_to_json_schema(field_type, field_attrs); + quote! { + properties.insert( + #field_name.to_string(), + serde_json::Value::Object(#schema) + ); + } + }); - let expanded = quote! { - impl #name { - pub fn json_schema() -> serde_json::Map { - let mut schema = serde_json::Map::new(); - let mut properties = serde_json::Map::new(); - let mut required = Vec::new(); + let required_fields = fields.named.iter().filter_map(|field| { + let renamed_field = renamed_field(&field.attrs); + let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string()); + + let field_type = &field.ty; + if !is_option(field_type) { + Some(quote! { + required.push(#field_name.to_string()); + }) + } else { + None + } + }); - #(#field_entries)* + quote! { + { + let mut map = serde_json::Map::new(); + let mut properties = serde_json::Map::new(); + let mut required = Vec::new(); - #(#required_fields)* + #(#field_entries)* - schema.insert("type".to_string(), serde_json::Value::String("object".to_string())); - schema.insert("properties".to_string(), serde_json::Value::Object(properties)); - if !required.is_empty() { - schema.insert("required".to_string(), serde_json::Value::Array( - required.into_iter().map(serde_json::Value::String).collect() - )); + #(#required_fields)* + + map.insert("type".to_string(), serde_json::Value::String("object".to_string())); + map.insert("properties".to_string(), serde_json::Value::Object(properties)); + if !required.is_empty() { + map.insert("required".to_string(), serde_json::Value::Array( + required.into_iter().map(serde_json::Value::String).collect() + )); + } + #title_quote + #description_quote + serde_json::Value::Object(map) + } + } + } } + }); + quote! { + let mut schema = serde_json::Map::new(); + schema.insert("oneOf".to_string(), serde_json::Value::Array(vec![ + #(#variant_schemas),* + ])); schema } } + _ => panic!("JsonSchema derive macro only supports structs and enums"), + }; + + let expanded = quote! { + impl #name { + pub fn json_schema() -> serde_json::Map { + #schema_body + } + } }; TokenStream::from(expanded) } diff --git a/crates/rust-mcp-macros/src/utils.rs b/crates/rust-mcp-macros/src/utils.rs index 0d4bbed..71d3de3 100644 --- a/crates/rust-mcp-macros/src/utils.rs +++ b/crates/rust-mcp-macros/src/utils.rs @@ -1,5 +1,8 @@ use quote::quote; -use syn::{punctuated::Punctuated, token, Attribute, Path, PathArguments, Type}; +use syn::{ + punctuated::Punctuated, token, Attribute, DeriveInput, Lit, LitInt, LitStr, Path, + PathArguments, Type, +}; // Check if a type is an Option pub fn is_option(ty: &Type) -> bool { @@ -13,8 +16,8 @@ pub fn is_option(ty: &Type) -> bool { false } -// Check if a type is a Vec #[allow(unused)] +// Check if a type is a Vec pub fn is_vec(ty: &Type) -> bool { if let Type::Path(type_path) = ty { if type_path.path.segments.len() == 1 { @@ -26,8 +29,8 @@ pub fn is_vec(ty: &Type) -> bool { false } -// Extract the inner type from Vec or Option #[allow(unused)] +// Extract the inner type from Vec or Option pub fn inner_type(ty: &Type) -> Option<&Type> { if let Type::Path(type_path) = ty { if type_path.path.segments.len() == 1 { @@ -46,12 +49,11 @@ pub fn inner_type(ty: &Type) -> Option<&Type> { None } -fn doc_comment(attrs: &[Attribute]) -> Option { +pub fn doc_comment(attrs: &[Attribute]) -> Option { let mut docs = Vec::new(); for attr in attrs { if attr.path().is_ident("doc") { if let syn::Meta::NameValue(meta) = &attr.meta { - // Match value as Expr::Lit, then extract Lit::Str if let syn::Expr::Lit(expr_lit) = &meta.value { if let syn::Lit::Str(lit_str) = &expr_lit.lit { docs.push(lit_str.value().trim().to_string()); @@ -82,16 +84,143 @@ pub fn might_be_struct(ty: &Type) -> bool { false } +// Helper to check if a type is an enum +pub fn is_enum(ty: &Type, _input: &DeriveInput) -> bool { + if let Type::Path(type_path) = ty { + // Check for #[mcp_elicit(enum)] attribute on the type + // Since we can't access the enum's definition directly, we rely on the attribute + // This assumes the enum is marked with #[mcp_elicit(enum)] in its definition + // Alternatively, we could pass a list of known enums, but attribute-based is simpler + type_path + .path + .segments + .last() + .map(|s| { + // For now, we'll assume any type could be an enum if it has the attribute + // In a real-world scenario, we'd need to resolve the type's definition + // For simplicity, we check if the type name is plausible (not String, bool, i32, i64) + let ident = s.ident.to_string(); + !["String", "bool", "i32", "i64"].contains(&ident.as_str()) + }) + .unwrap_or(false) + } else { + false + } +} + +// Helper to generate enum parsing code +pub fn generate_enum_parse( + field_type: &Type, + field_name: &str, + base_crate: &proc_macro2::TokenStream, +) -> proc_macro2::TokenStream { + let type_ident = match field_type { + Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(), + _ => panic!("Expected path type for enum"), + }; + // Since we can't access the enum's variants directly in this context, + // we'll assume the enum has unit variants and expect strings matching their names + // In a real-world scenario, you'd parse the enum's Data::Enum to get variant names + // For now, we'll generate a generic parse assuming variant names are provided as strings + quote! { + { + // Attempt to parse the string using a match + // Since we don't have the variants, we rely on the enum implementing FromStr + match s.as_str() { + // We can't dynamically list variants, so we use FromStr + // If FromStr is not implemented, this will fail at compile time + s => s.parse().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!( + "Invalid enum value for field '{}': cannot parse '{}' into {}", + #field_name, s, stringify!(#type_ident) + )))? + } + } + } +} + pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::TokenStream { - let number_types = [ - "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64", + let integer_types = [ + "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", ]; - let doc_comment = doc_comment(attrs); - let description = doc_comment.as_ref().map(|desc| { + let float_types = ["f32", "f64"]; + + // Parse custom json_schema attributes + let mut title: Option = None; + let mut format: Option = None; + let mut min_length: Option = None; + let mut max_length: Option = None; + let mut minimum: Option = None; + let mut maximum: Option = None; + let mut default: Option = None; + let mut attr_description: Option = None; + + for attr in attrs { + if attr.path().is_ident("json_schema") { + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("title") { + title = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("description") { + attr_description = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("format") { + format = Some(meta.value()?.parse::()?.value()); + } else if meta.path.is_ident("min_length") { + min_length = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("max_length") { + max_length = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("minimum") { + minimum = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("maximum") { + maximum = Some(meta.value()?.parse::()?.base10_parse::()?); + } else if meta.path.is_ident("default") { + let lit = meta.value()?.parse::()?; + default = Some(match lit { + Lit::Str(lit_str) => { + let value = lit_str.value(); + quote! { serde_json::Value::String(#value.to_string()) } + } + Lit::Int(lit_int) => { + let value = lit_int.base10_parse::()?; + assert!( + (i64::MIN..=i64::MAX).contains(&value), + "Default value {value} out of range for i64" + ); + quote! { serde_json::Value::Number(serde_json::Number::from(#value)) } + } + Lit::Float(lit_float) => { + let value = lit_float.base10_parse::()?; + quote! { serde_json::Value::Number(serde_json::Number::from_f64(#value).expect("Invalid float")) } + } + Lit::Bool(lit_bool) => { + let value = lit_bool.value(); + quote! { serde_json::Value::Bool(#value) } + } + _ => return Err(meta.error("Unsupported default value type")), + }); + } + Ok(()) + }); + } + } + + let description = attr_description.or(doc_comment(attrs)); + let description_quote = description.as_ref().map(|desc| { quote! { map.insert("description".to_string(), serde_json::Value::String(#desc.to_string())); } }); + + let title_quote = title.as_ref().map(|t| { + quote! { + map.insert("title".to_string(), serde_json::Value::String(#t.to_string())); + } + }); + + let default_quote = default.as_ref().map(|d| { + quote! { + map.insert("default".to_string(), #d); + } + }); + match ty { Type::Path(type_path) => { if type_path.path.segments.len() == 1 { @@ -104,15 +233,43 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token if args.args.len() == 1 { if let syn::GenericArgument::Type(inner_ty) = &args.args[0] { let inner_schema = type_to_json_schema(inner_ty, attrs); + let format_quote = format.as_ref().map(|f| { + quote! { + map.insert("format".to_string(), serde_json::Value::String(#f.to_string())); + } + }); + let min_quote = min_length.as_ref().map(|min| { + quote! { + map.insert("minLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = max_length.as_ref().map(|max| { + quote! { + map.insert("maxLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); + let min_num_quote = minimum.as_ref().map(|min| { + quote! { + map.insert("minimum".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_num_quote = maximum.as_ref().map(|max| { + quote! { + map.insert("maximum".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { - let mut map = serde_json::Map::new(); - let inner_map = #inner_schema; - for (k, v) in inner_map { - map.insert(k, v); - } + let mut map = #inner_schema; map.insert("nullable".to_string(), serde_json::Value::Bool(true)); - #description + #description_quote + #title_quote + #format_quote + #min_quote + #max_quote + #min_num_quote + #max_num_quote + #default_quote map } }; @@ -126,12 +283,26 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token if args.args.len() == 1 { if let syn::GenericArgument::Type(inner_ty) = &args.args[0] { let inner_schema = type_to_json_schema(inner_ty, &[]); + let min_quote = min_length.as_ref().map(|min| { + quote! { + map.insert("minItems".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = max_length.as_ref().map(|max| { + quote! { + map.insert("maxItems".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("array".to_string())); map.insert("items".to_string(), serde_json::Value::Object(#inner_schema)); - #description + #description_quote + #title_quote + #min_quote + #max_quote + #default_quote map } }; @@ -144,36 +315,104 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token let path = &type_path.path; return quote! { { - let inner_schema = #path::json_schema(); - inner_schema + let mut map = #path::json_schema(); + #description_quote + #title_quote + #default_quote + map } }; } - // Handle basic types + // Handle String else if ident == "String" { + let format_quote = format.as_ref().map(|f| { + quote! { + map.insert("format".to_string(), serde_json::Value::String(#f.to_string())); + } + }); + let min_quote = min_length.as_ref().map(|min| { + quote! { + map.insert("minLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = max_length.as_ref().map(|max| { + quote! { + map.insert("maxLength".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("string".to_string())); - #description + #description_quote + #title_quote + #format_quote + #min_quote + #max_quote + #default_quote map } }; - } else if number_types.iter().any(|t| ident == t) { + } + // Handle integer types + else if integer_types.iter().any(|t| ident == t) { + let min_quote = minimum.as_ref().map(|min| { + quote! { + map.insert("minimum".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = maximum.as_ref().map(|max| { + quote! { + map.insert("maximum".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); + return quote! { + { + let mut map = serde_json::Map::new(); + map.insert("type".to_string(), serde_json::Value::String("integer".to_string())); + #description_quote + #title_quote + #min_quote + #max_quote + #default_quote + map + } + }; + } + // Handle float types + else if float_types.iter().any(|t| ident == t) { + let min_quote = minimum.as_ref().map(|min| { + quote! { + map.insert("minimum".to_string(), serde_json::Value::Number(serde_json::Number::from(#min))); + } + }); + let max_quote = maximum.as_ref().map(|max| { + quote! { + map.insert("maximum".to_string(), serde_json::Value::Number(serde_json::Number::from(#max))); + } + }); return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("number".to_string())); - #description + #description_quote + #title_quote + #min_quote + #max_quote + #default_quote map } }; - } else if ident == "bool" { + } + // Handle bool + else if ident == "bool" { return quote! { { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("boolean".to_string())); - #description + #description_quote + #title_quote + #default_quote map } }; @@ -184,7 +423,9 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("unknown".to_string())); - #description + #description_quote + #title_quote + #default_quote map } } @@ -193,7 +434,9 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token { let mut map = serde_json::Map::new(); map.insert("type".to_string(), serde_json::Value::String("unknown".to_string())); - #description + #description_quote + #title_quote + #default_quote map } }, @@ -204,7 +447,6 @@ pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::Token pub fn has_derive(attrs: &[Attribute], trait_name: &str) -> bool { attrs.iter().any(|attr| { if attr.path().is_ident("derive") { - // Parse the derive arguments as a comma-separated list of paths let parsed = attr.parse_args_with(Punctuated::::parse_terminated); if let Ok(derive_paths) = parsed { let derived = derive_paths.iter().any(|path| path.is_ident(trait_name)); @@ -220,7 +462,6 @@ pub fn renamed_field(attrs: &[Attribute]) -> Option { for attr in attrs { if attr.path().is_ident("serde") { - // Ignore other serde meta items (e.g., skip_serializing_if) let _ = attr.parse_nested_meta(|meta| { if meta.path.is_ident("rename") { if let Ok(lit) = meta.value() { @@ -493,12 +734,12 @@ mod tests { } #[test] - fn test_json_schema_number() { + fn test_json_schema_integer() { let ty: syn::Type = parse_quote!(i32); let tokens = type_to_json_schema(&ty, &[]); let output = render(tokens); assert!(output - .contains("\"type\".to_string(),serde_json::Value::String(\"number\".to_string())")); + .contains("\"type\".to_string(),serde_json::Value::String(\"integer\".to_string())")); } #[test] @@ -527,7 +768,7 @@ mod tests { let output = render(tokens); assert!(output.contains("\"nullable\".to_string(),serde_json::Value::Bool(true)")); assert!(output - .contains("\"type\".to_string(),serde_json::Value::String(\"number\".to_string())")); + .contains("\"type\".to_string(),serde_json::Value::String(\"integer\".to_string())")); } #[test] diff --git a/crates/rust-mcp-macros/tests/common/common.rs b/crates/rust-mcp-macros/tests/common/common.rs index 40c4e3c..d6bae2e 100644 --- a/crates/rust-mcp-macros/tests/common/common.rs +++ b/crates/rust-mcp-macros/tests/common/common.rs @@ -1,4 +1,7 @@ +use std::str::FromStr; + use rust_mcp_macros::JsonSchema; +use rust_mcp_schema::RpcError; #[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] /// Represents a text replacement operation. @@ -26,3 +29,50 @@ pub struct EditFileTool { )] pub dry_run: Option, } + +#[derive(JsonSchema, Debug)] +pub enum Colors { + #[json_schema(title = "Green Color")] + Green, + #[json_schema(title = "Red Color")] + Red, +} + +impl FromStr for Colors { + type Err = RpcError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "green" => Ok(Colors::Green), + "red" => Ok(Colors::Red), + _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), + } + } +} + +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + + /// Email address of the user + #[json_schema(title = "Email", format = "email")] + pub email: Option, + + /// The user's age in years + #[json_schema(title = "Age", minimum = 15, maximum = 125)] + pub age: i32, + + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} diff --git a/crates/rust-mcp-macros/tests/macro_test.rs b/crates/rust-mcp-macros/tests/macro_test.rs index 3a23c87..4b6c926 100644 --- a/crates/rust-mcp-macros/tests/macro_test.rs +++ b/crates/rust-mcp-macros/tests/macro_test.rs @@ -1,4 +1,16 @@ +#[macro_use] +extern crate rust_mcp_macros; + +use std::collections::HashMap; + use common::EditOperation; +use rust_mcp_schema::{ + BooleanSchema, ElicitRequestedSchema, ElicitResultContentValue, EnumSchema, NumberSchema, + PrimitiveSchemaDefinition, StringSchema, StringSchemaFormat, +}; +use serde_json::json; + +use crate::common::{Colors, UserInfo}; #[path = "common/common.rs"] pub mod common; @@ -31,3 +43,232 @@ fn test_rename() { let properties = schema.get("properties").unwrap().as_object().unwrap(); assert_eq!(properties.len(), 2); } + +#[test] +fn test_attributes() { + #[derive(JsonSchema)] + struct User { + /// This is a fallback description from doc comment. + pub id: i32, + + #[json_schema( + title = "User Name", + description = "The user's full name (overrides doc)", + min_length = 1, + max_length = 100 + )] + pub name: String, + + #[json_schema( + title = "User Email", + format = "email", + min_length = 5, + max_length = 255 + )] + pub email: Option, + + #[json_schema( + title = "Tags", + description = "List of tags", + min_length = 0, + max_length = 10 + )] + pub tags: Vec, + } + + let schema = User::json_schema(); + let expected = json!({ + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "This is a fallback description from doc comment." + }, + "name": { + "type": "string", + "title": "User Name", + "description": "The user's full name (overrides doc)", + "minLength": 1, + "maxLength": 100 + }, + "email": { + "type": "string", + "title": "User Email", + "format": "email", + "minLength": 5, + "maxLength": 255, + "nullable": true + }, + "tags": { + "type": "array", + "items": { + "type": "string", + }, + "title": "Tags", + "description": "List of tags", + "minItems": 0, + "maxItems": 10 + } + }, + "required": ["id", "name", "tags"] + }); + + // Convert expected_value from serde_json::Value to serde_json::Map + let expected: serde_json::Map = + expected.as_object().expect("Expected JSON object").clone(); + + assert_eq!(schema, expected); +} + +#[test] +fn test_elicit_macro() { + assert_eq!(UserInfo::message(), "Please enter your info"); + + let requested_schema: ElicitRequestedSchema = UserInfo::requested_schema(); + assert_eq!( + requested_schema.required, + vec!["name", "age", "favorate_color"] + ); + + assert!(matches!( + requested_schema.properties.get("is_student").unwrap(), + PrimitiveSchemaDefinition::BooleanSchema(BooleanSchema { + default, + description, + title, + .. + }) + if + description.as_ref().unwrap() == "Is user a student?" && + title.as_ref().unwrap() == "Is student?" && + matches!(default, Some(true)) + + )); + + assert!(matches!( + requested_schema.properties.get("favorate_color").unwrap(), + PrimitiveSchemaDefinition::EnumSchema(EnumSchema { + description, + enum_, + enum_names, + title, + .. + }) + if description.as_ref().unwrap() == "User's favorite color" && + title.is_none() && + enum_.len()==2 && enum_.iter().all(|s| ["Green", "Red"].contains(&s.as_str())) && + enum_names.len()==2 && enum_names.iter().all(|s| ["Green Color", "Red Color"].contains(&s.as_str())) + )); + + assert!(matches!( + requested_schema.properties.get("age").unwrap(), + PrimitiveSchemaDefinition::NumberSchema(NumberSchema { + description, + maximum, + minimum, + title, + type_ + }) + if + description.as_ref().unwrap() == "The user's age in years" && + maximum.unwrap() == 125 && minimum.unwrap() == 15 && title.as_ref().unwrap() == "Age" + )); + + assert!(matches!( + requested_schema.properties.get("name").unwrap(), + PrimitiveSchemaDefinition::StringSchema(StringSchema { + description, + format, + max_length, + min_length, + title, + .. + }) + if format.is_none() && + description.as_ref().unwrap() == "The user's full name" && + max_length.unwrap() == 100 && min_length.unwrap() == 5 && title.as_ref().unwrap() == "Name" + )); + + assert!(matches!( + requested_schema.properties.get("email").unwrap(), + PrimitiveSchemaDefinition::StringSchema(StringSchema { + description, + format, + max_length, + min_length, + title, + .. + }) if matches!(format.unwrap(), StringSchemaFormat::Email) && + description.as_ref().unwrap() == "Email address of the user" && + max_length.is_none() && min_length.is_none() && title.as_ref().unwrap() == "Email" + )); + + let json_schema = &UserInfo::json_schema(); + + let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) { + Some(arr) => arr + .iter() + .filter_map(|item| item.as_str().map(String::from)) + .collect(), + None => Vec::new(), + }; + + let properties: Option> = json_schema + .get("properties") + .and_then(|v| v.as_object()) // Safely extract "properties" as an object. + .map(|properties| { + properties + .iter() + .filter_map(|(key, value)| { + serde_json::to_value(value) + .ok() // If serialization fails, return None. + .and_then(|v| { + if let serde_json::Value::Object(obj) = v { + Some(obj) + } else { + None + } + }) + .map(|obj| (key.to_string(), PrimitiveSchemaDefinition::try_from(&obj))) + }) + .collect() + }); + + let properties = properties + .map(|map| { + map.into_iter() + .map(|(k, v)| v.map(|ok_v| (k, ok_v))) // flip Result inside tuple + .collect::, _>>() // collect only if all Ok + }) + .transpose() + .unwrap(); + + let properties = properties.expect("Was not able to create a ElicitRequestedSchema"); + + ElicitRequestedSchema::new(properties, required); +} + +#[test] +fn test_from_content_map() { + let mut content: ::std::collections::HashMap<::std::string::String, ElicitResultContentValue> = + HashMap::new(); + + content.extend([ + ( + "name".to_string(), + ElicitResultContentValue::String("Ali".to_string()), + ), + ( + "favorate_color".to_string(), + ElicitResultContentValue::String("Green".to_string()), + ), + ("age".to_string(), ElicitResultContentValue::Integer(15)), + ( + "is_student".to_string(), + ElicitResultContentValue::Boolean(false), + ), + ]); + + let u: UserInfo = UserInfo::from_content_map(Some(content)).unwrap(); + assert!(matches!(u.favorate_color, Colors::Green)); +} diff --git a/crates/rust-mcp-sdk/CHANGELOG.md b/crates/rust-mcp-sdk/CHANGELOG.md index 8f2f4f7..4fde908 100644 --- a/crates/rust-mcp-sdk/CHANGELOG.md +++ b/crates/rust-mcp-sdk/CHANGELOG.md @@ -1,5 +1,68 @@ # Changelog +## [0.7.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.3...rust-mcp-sdk-v0.7.0) (2025-09-19) + + +### ⚠ BREAKING CHANGES + +* add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) +* update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) + +### πŸš€ Features + +* Add elicitation macros and add elicit_input() method ([#99](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/99)) ([3ab5fe7](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/3ab5fe73aaa10de2b5b23caee357ac15b37c845f)) +* Add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) ([abb0c36](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/abb0c36126b0a397bc20a1de36c5a5a80924a01e)) +* Add tls-no-provider feature ([#97](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/97)) ([5dacceb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/5dacceb0c2d18b8334744a13d438c6916bb7244c)) +* Event store support for resumability ([#101](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/101)) ([08742bb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/08742bb9636f81ee79eda4edc192b3b8ed4c7287)) +* Update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) ([a2d6d23](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/a2d6d23ab59fbc34d04526e2606f747f93a8468c)) + +## [0.6.3](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.2...rust-mcp-sdk-v0.6.3) (2025-08-31) + +## [0.6.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.1...rust-mcp-sdk-v0.6.2) (2025-08-30) + + +### πŸ› Bug Fixes + +* Tool-box macro panic on invalid requests ([#92](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/92)) ([54cc8ed](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/54cc8edb55c41455dd9211f296560e7a792a7b9c)) + +## [0.6.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.6.0...rust-mcp-sdk-v0.6.1) (2025-08-28) + + +### πŸ› Bug Fixes + +* Session ID access in handlers and add helper for listing active ([#90](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/90)) ([f2f0afb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/f2f0afb542f6ff036a28cf01e102b27ce940665b)) + +## [0.6.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.3...rust-mcp-sdk-v0.6.0) (2025-08-19) + + +### ⚠ BREAKING CHANGES + +* improve request ID generation, remove deprecated methods and adding improvements + +### πŸš€ Features + +* Improve request ID generation, remove deprecated methods and adding improvements ([95b91aa](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/95b91aad191e1b8777ca4a02612ab9183e0276d3)) + +## [0.5.3](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.2...rust-mcp-sdk-v0.5.3) (2025-08-19) + + +### πŸ› Bug Fixes + +* Handle missing client details and abort keep-alive task on drop ([#83](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/83)) ([308b1db](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/308b1dbd1744ff06046902303d8bcd6c3a92ffbe)) + +## [0.5.2](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.1...rust-mcp-sdk-v0.5.2) (2025-08-16) + + +### πŸš€ Features + +* Integrate list root and client info into hyper runtime ([36dfa4c](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/36dfa4cdc821e958ffe78b909ed28f5577d113c8)) + + +### πŸ› Bug Fixes + +* Abort keep-alive task when transport is removed ([#82](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/82)) ([1ca8e49](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/1ca8e49860e990c3562623e75dd723b0d1dc8256)) +* Ensure server-initiated requests include a valid request_id ([#80](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/80)) ([5f9a966](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/5f9a966bb523bf61daefcff209199bc774fa5ed6)) + ## [0.5.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-sdk-v0.5.0...rust-mcp-sdk-v0.5.1) (2025-08-12) diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 5f28fa3..8bba7c7 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-sdk" -version = "0.5.1" +version = "0.7.0" authors = ["Ali Hashemi"] categories = ["data-structures", "parser-implementations", "parsing"] description = "An asynchronous SDK and framework for building MCP-Servers and MCP-Clients, leveraging the rust-mcp-schema for type safe MCP Schema Objects." @@ -24,15 +24,17 @@ futures = { workspace = true } thiserror = { workspace = true } axum = { workspace = true, optional = true } -uuid = { workspace = true, features = ["v4"], optional = true } +uuid = { workspace = true, features = ["v4"] } tokio-stream = { workspace = true, optional = true } axum-server = { version = "0.7", features = [], optional = true } tracing.workspace = true +base64.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } [dev-dependencies] +wiremock = "0.5" reqwest = { workspace = true, default-features = false, features = [ "stream", "rustls-tls", @@ -51,47 +53,55 @@ default = [ "client", "server", "macros", + "stdio", + "sse", + "streamable-http", "hyper-server", "ssl", "2025_06_18", ] # All features enabled by default -server = ["rust-mcp-transport/stdio"] # Server feature -client = ["rust-mcp-transport/stdio", "rust-mcp-transport/sse"] # Client feature -hyper-server = [ - "axum", - "axum-server", - "hyper", - "server", - "uuid", - "tokio-stream", - "rust-mcp-transport/sse", -] + +sse = ["rust-mcp-transport/sse"] +streamable-http = ["rust-mcp-transport/streamable-http"] +stdio = ["rust-mcp-transport/stdio"] + +server = [] # Server feature +client = [] # Client feature +hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream"] ssl = ["axum-server/tls-rustls"] +tls-no-provider = ["axum-server/tls-rustls-no-provider"] macros = ["rust-mcp-macros/sdk"] -# enables mcp protocol version 2025_06_18 -2025_06_18 = [ +# enables mcp protocol version 2025-06-18 +2025-06-18 = [ "rust-mcp-schema/2025_06_18", "rust-mcp-macros/2025_06_18", "rust-mcp-transport/2025_06_18", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_06_18 = ["2025-06-18"] # enables mcp protocol version 2025_03_26 -2025_03_26 = [ +2025-03-26 = [ "rust-mcp-schema/2025_03_26", "rust-mcp-macros/2025_03_26", "rust-mcp-transport/2025_03_26", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_03_26 = ["2025-03-26"] + # enables mcp protocol version 2024_11_05 -2024_11_05 = [ +2024-11-05 = [ "rust-mcp-schema/2024_11_05", "rust-mcp-macros/2024_11_05", "rust-mcp-transport/2024_11_05", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2024_11_05 = ["2024-11-05"] [lints] workspace = true diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index ef5b4ed..2c70c3e 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -32,15 +32,14 @@ This project supports following transports: πŸš€ The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - βœ… Streamable HTTP Support for MCP Servers - βœ… DNS Rebinding Protection - βœ… Batch Messages - βœ… Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients -- ⬜ Resumability -- ⬜ Authentication / Oauth +- βœ… Streamable HTTP Support for MCP Clients +- βœ… Resumability +- ⬜ Oauth Authentication **⚠️** Project is currently under development and should be used at your own risk. @@ -49,7 +48,9 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) +- [Macros](#macros) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) - [Security Considerations](#security-considerations) @@ -110,7 +111,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -153,6 +154,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); @@ -180,7 +182,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +193,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -205,7 +207,7 @@ impl ServerHandler for MyServerHandler { --- -πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +πŸ‘‰ For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -283,6 +285,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -294,8 +298,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +πŸ‘‰ see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -306,6 +384,116 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +πŸ‘‰ see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + + +## Macros +[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. + +> To use these macros, ensure the `macros` feature is enabled in your Cargo.toml. + +### mcp_tool +`mcp_tool` is a procedural macro attribute that helps generating rust_mcp_schema::Tool from a struct. + +Usage example: +```rust +#[mcp_tool( + name = "move_file", + title="Move File", + description = concat!("Move or rename files and directories. Can move files between directories ", +"and rename them in a single operation. If the destination exists, the ", +"operation will fail. Works across different directories and can be used ", +"for simple renaming within the same directory. ", +"Both source and destination must be within allowed directories."), + destructive_hint = false, + idempotent_hint = false, + open_world_hint = false, + read_only_hint = false +)] +#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] +pub struct MoveFileTool { + /// The source path of the file to move. + pub source: String, + /// The destination path to move the file to. + pub destination: String, +} + +// Now we can call `tool()` method on it to get a Tool instance +let rust_mcp_sdk::schema::Tool = MoveFileTool::tool(); + +``` + +πŸ’» For a real-world example, check out any of the tools available at: https://github.com/rust-mcp-stack/rust-mcp-filesystem/tree/main/src/tools + + +### tool_box +`tool_box` generates an enum from a provided list of tools, making it easier to organize and manage them, especially when your application includes a large number of tools. + +It accepts an array of tools and generates an enum where each tool becomes a variant of the enum. + +Generated enum has a `tools()` function that returns a `Vec` , and a `TryFrom` trait implementation that could be used to convert a ToolRequest into a Tool instance. + +Usage example: +```rust + // Accepts an array of tools and generates an enum named `FileSystemTools`, + // where each tool becomes a variant of the enum. + tool_box!(FileSystemTools, [ReadFileTool, MoveFileTool, SearchFilesTool]); + + // now in the app, we can use the FileSystemTools, like: + let all_tools: Vec = FileSystemTools::tools(); +``` + +πŸ’» To see a real-world example of that please see : +- `tool_box` macro usage: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/tools.rs) +- using `tools()` in list tools request : [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L67) +- using `try_from` in call tool_request: [https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-filesystem/blob/main/src/handler.rs#L100) + + + +### mcp_elicit +The `mcp_elicit` macro generates implementations for the annotated struct to facilitate data elicitation. It enables struct to generate `ElicitRequestedSchema` and also parsing a map of field names to `ElicitResultContentValue` values back into the struct, supporting both required and optional fields. The generated implementation includes: + +- A `message()` method returning the elicitation message as a string. +- A `requested_schema()` method returning an `ElicitRequestedSchema` based on the struct’s JSON schema. +- A `from_content_map()` method to convert a map of `ElicitResultContentValue` values into a struct instance. + +### Attributes + +- `message` - An optional string (or `concat!(...)` expression) to prompt the user or system for input. Defaults to an empty string if not provided. + +Usage example: +```rust +// A struct that could be used to send elicit request and get the input from the user +#[mcp_elicit(message = "Please enter your info")] +#[derive(JsonSchema)] +pub struct UserInfo { + #[json_schema( + title = "Name", + description = "The user's full name", + min_length = 5, + max_length = 100 + )] + pub name: String, + /// Is user a student? + #[json_schema(title = "Is student?", default = true)] + pub is_student: Option, + + /// User's favorite color + pub favorate_color: Colors, +} + +// send a Elicit Request , ask for UserInfo data and convert the result back to a valid UserInfo instance +let result: ElicitResult = server + .elicit_input(UserInfo::message(), UserInfo::requested_schema()) + .await?; + +// Create a UserInfo instance using data provided by the user on the client side +let user_info = UserInfo::from_content_map(result.content)?; + +``` + +πŸ’» For mre info please see : +- https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros ## Getting Started @@ -337,6 +525,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -344,9 +533,19 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -356,12 +555,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -373,17 +566,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -395,6 +577,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -416,9 +609,15 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. + +- `stdio`: Enables support for the `standard input/output (stdio)` transport. +- `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. + #### MCP Protocol Versions with Corresponding Features @@ -449,9 +648,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -464,7 +663,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` @@ -477,10 +676,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -509,7 +708,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK @@ -526,6 +725,11 @@ Below is a list of projects that utilize the `rust-mcp-sdk`, showcasing their na | | [text-to-cypher](https://github.com/FalkorDB/text-to-cypher) | A high-performance Rust-based API service that translates natural language text to Cypher queries for graph databases. | [GitHub](https://github.com/FalkorDB/text-to-cypher) | | | [notify-mcp](https://github.com/Tuurlijk/notify-mcp) | A Model Context Protocol (MCP) server that provides desktop notification functionality. | [GitHub](https://github.com/Tuurlijk/notify-mcp) | | | [lst](https://github.com/WismutHansen/lst) | `lst` is a personal lists, notes, and blog posts management application with a focus on plain-text storage, offline-first functionality, and multi-device synchronization. | [GitHub](https://github.com/WismutHansen/lst) | +| | [rust-mcp-server](https://github.com/Vaiz/rust-mcp-server) | `rust-mcp-server` allows the model to perform actions on your behalf, such as building, testing, and analyzing your Rust code. | [GitHub](https://github.com/Vaiz/rust-mcp-server) | + + + + diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 2feab67..3879526 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -11,25 +11,36 @@ pub type SdkResult = core::result::Result; #[derive(Debug, Error)] pub enum McpSdkError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("{0}")] RpcError(#[from] RpcError), + #[error("{0}")] - IoError(#[from] std::io::Error), - #[error("{0}")] - TransportError(#[from] TransportError), - #[error("{0}")] - JoinError(#[from] JoinError), - #[error("{0}")] - AnyError(Box<(dyn std::error::Error + Send + Sync)>), - #[error("{0}")] - SdkError(#[from] crate::schema::schema_utils::SdkError), + Join(#[from] JoinError), + #[cfg(feature = "hyper-server")] #[error("{0}")] - TransportServerError(#[from] TransportServerError), - #[error("Incompatible mcp protocol version: requested:{0} current:{1}")] - IncompatibleProtocolVersion(String, String), + HyperServer(#[from] TransportServerError), + #[error("{0}")] - ParseProtocolVersionError(#[from] ParseProtocolVersionError), + SdkError(#[from] crate::schema::schema_utils::SdkError), + + #[error("Protocol error: {kind}")] + Protocol { kind: ProtocolErrorKind }, +} + +// Sub-enum for protocol-related errors +#[derive(Debug, Error)] +pub enum ProtocolErrorKind { + #[error("Incompatible protocol version: requested {requested}, current {current}")] + IncompatibleVersion { requested: String, current: String }, + #[error("Failed to parse protocol version: {0}")] + ParseError(#[from] ParseProtocolVersionError), } impl McpSdkError { @@ -41,6 +52,3 @@ impl McpSdkError { None } } - -#[deprecated(since = "0.2.0", note = "Use `McpSdkError` instead.")] -pub type MCPSdkError = McpSdkError; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index 0c1dcf3..f96b261 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -1,11 +1,12 @@ use std::{sync::Arc, time::Duration}; -use crate::schema::InitializeResult; -use rust_mcp_transport::TransportOptions; - +use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; +use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; + +use rust_mcp_transport::event_store::EventStore; -use super::{session_store::SessionStore, IdGenerator}; +use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server /// @@ -14,7 +15,8 @@ use super::{session_store::SessionStore, IdGenerator}; #[derive(Clone)] pub struct AppState { pub session_store: Arc, - pub id_generator: Arc, + pub id_generator: Arc>, + pub stream_id_gen: Arc, pub server_details: Arc, pub handler: Arc, pub ping_interval: Duration, @@ -31,6 +33,9 @@ pub struct AppState { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, } impl AppState { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs index 30df951..85cf791 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -4,7 +4,8 @@ use crate::{ mcp_server::HyperServer, schema::{ schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, - CreateMessageRequestParams, CreateMessageResult, LoggingMessageNotificationParams, + CreateMessageRequestParams, CreateMessageResult, InitializeRequestParams, + ListRootsRequestParams, ListRootsResult, LoggingMessageNotificationParams, PromptListChangedNotificationParams, ResourceListChangedNotificationParams, ResourceUpdatedNotificationParams, ToolListChangedNotificationParams, }, @@ -69,6 +70,12 @@ impl HyperRuntime { result.map_err(|err| err.into()) } + /// Returns a list of active session IDs from the session store. + pub async fn sessions(&self) -> Vec { + self.state.session_store.keys().await + } + + /// Retrieves the runtime associated with the given session ID from the session store. pub async fn runtime_by_session( &self, session_id: &SessionId, @@ -99,6 +106,21 @@ impl HyperRuntime { runtime.send_notification(notification).await } + /// Request a list of root URIs from the client. Roots allow + /// servers to ask for specific directories or files to operate on. A common example + /// for roots is providing a set of repositories or directories a server should operate on. + /// This request is typically used when the server needs to understand the file system + /// structure or access specific locations that the client has permission to read from + pub async fn list_roots( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.list_roots(params).await + } + pub async fn send_logging_message( &self, session_id: &SessionId, @@ -195,4 +217,13 @@ impl HyperRuntime { let runtime = runtime.lock().await.to_owned(); runtime.create_message(params).await } + + pub async fn client_info( + &self, + session_id: &SessionId, + ) -> SdkResult> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + Ok(runtime.client_info()) + } } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 79bf226..7101a73 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -6,7 +6,7 @@ use crate::{ }, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::mcp_handler::McpServerHandler, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, utils::validate_mcp_protocol_version, }; @@ -22,13 +22,13 @@ use axum::{ }; use futures::stream; use hyper::{header, HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, SseTransport}; +use rust_mcp_transport::{ + EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, +}; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; -pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; -pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; - const DUPLEX_BUFFER_SIZE: usize = 8192; async fn create_sse_stream( @@ -37,15 +37,16 @@ async fn create_sse_stream( state: Arc, payload: Option<&str>, standalone: bool, + last_event_id: Option, ) -> TransportServerResult> { let payload_string = payload.map(|p| p.to_string()); // TODO: this logic should be moved out after refactoing the mcp_stream.rs - let result = payload_string + let payload_contains_request = payload_string .as_ref() .map(|json_str| contains_request(json_str)) .unwrap_or(Ok(false)); - let Ok(payload_contains_request) = result else { + let Ok(payload_contains_request) = payload_contains_request else { return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); }; @@ -54,47 +55,85 @@ async fn create_sse_stream( // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - let transport = SseTransport::::new( + let session_id = Arc::new(session_id); + let stream_id: Arc = if standalone { + Arc::new(DEFAULT_STREAM_ID.to_string()) + } else { + Arc::new(state.stream_id_gen.generate()) + }; + + let event_store = state.event_store.as_ref().map(Arc::clone); + let resumability_enabled = event_store.is_some(); + + let mut transport = SseTransport::::new( read_rx, write_tx, read_tx, Arc::clone(&state.transport_options), ) .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + if let Some(event_store) = event_store.clone() { + transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); + } + let transport = Arc::new(transport); - let stream_id = if standalone { - DEFAULT_STREAM_ID.to_string() - } else { - state.id_generator.generate() - }; let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); + let stream_id_clone = stream_id.clone(); + let transport_clone = transport.clone(); //Start the server runtime tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream( + transport_clone, + &stream_id_clone, + ping_interval, + payload_string, + ) .await { - Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), } - let _ = runtime.remove_transport(&stream_id).await; + let _ = runtime.remove_transport(&stream_id_clone).await; }); // Construct SSE stream let reader = BufReader::new(write_rx); - let message_stream = stream::unfold(reader, |mut reader| async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some((Ok(Event::default().data(trimmed_line)), reader)) + // send outgoing messages from server to the client over the sse stream + let message_stream = stream::unfold(reader, move |mut reader| { + async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + + // empty sse comment to keep-alive + if is_empty_sse_message(&trimmed_line) { + return Some((Ok(Event::default()), reader)); + } + + let (event_id, message) = match ( + resumability_enabled, + trimmed_line.split_once(char::from(ID_SEPARATOR)), + ) { + (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), + _ => (None, trimmed_line), + }; + + let event = match event_id { + Some(id) => Event::default().data(message).id(id), + None => Event::default().data(message), + }; + + Some((Ok(event), reader)) + } + Err(e) => Some((Err(e), reader)), } - Err(e) => Some((Err(e), reader)), } }); @@ -109,6 +148,23 @@ async fn create_sse_stream( HeaderValue::from_str(&session_id).unwrap(), ); + // if last_event_id exists we replay messages from the event-store + tokio::spawn(async move { + if let Some(last_event_id) = last_event_id { + if let Some(event_store) = state.event_store.as_ref() { + if let Some(events) = event_store.events_after(last_event_id).await { + for message_payload in events.messages { + // skip storing replay messages + let error = transport.write_str(&message_payload, true).await; + if let Err(error) = error { + tracing::trace!("Error replaying message: {error}") + } + } + } + } + } + }); + if !payload_contains_request { *response.status_mut() = StatusCode::ACCEPTED; } @@ -117,12 +173,12 @@ async fn create_sse_stream( // TODO: this function will be removed after refactoring the readable stream of the transports // so we would deserialize the string syncronousely and have more control over the flow -// this function could potentially add a 20-250 ns overhead which could be avoided +// this function may incur a slight runtime cost which could be avoided after refactoring fn contains_request(json_str: &str) -> Result { let value: serde_json::Value = serde_json::from_str(json_str)?; match value { serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")), - serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| { item.as_object() .map(|obj| obj.contains_key("id") && obj.contains_key("method")) .unwrap_or(false) @@ -131,8 +187,22 @@ fn contains_request(json_str: &str) -> Result { } } +fn is_result(json_str: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(json_str)?; + match value { + serde_json::Value::Object(obj) => Ok(obj.contains_key("result")), + serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + item.as_object() + .map(|obj| obj.contains_key("result")) + .unwrap_or(false) + })), + _ => Ok(false), + } +} + pub async fn create_standalone_stream( session_id: SessionId, + last_event_id: Option, state: Arc, ) -> TransportServerResult> { let runtime = state.session_store.get(&session_id).await.ok_or( @@ -146,12 +216,20 @@ pub async fn create_standalone_stream( return Ok((StatusCode::CONFLICT, Json(error)).into_response()); } + if let Some(last_event_id) = last_event_id.as_ref() { + tracing::trace!( + "SSE stream re-connected with last-event-id: {}", + last_event_id + ); + } + let mut response = create_sse_stream( runtime.clone(), session_id.clone(), state.clone(), None, true, + last_event_id, ) .await?; *response.status_mut() = StatusCode::OK; @@ -166,23 +244,21 @@ pub async fn start_new_session( let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let runtime: Arc = Arc::new(server_runtime::create_server_instance( + let runtime: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); - - tracing::info!( - "a new client joined : {}", - runtime.session_id().await.unwrap_or_default().to_owned() ); + tracing::info!("a new client joined : {}", &session_id); + let response = create_sse_stream( runtime.clone(), session_id.clone(), state.clone(), Some(payload), false, + None, ) .await; @@ -227,7 +303,12 @@ async fn single_shot_stream( tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream( + Arc::new(transport), + &stream_id, + ping_interval, + payload_string, + ) .await { Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), @@ -236,7 +317,6 @@ async fn single_shot_stream( let _ = runtime.remove_transport(&stream_id).await; }); - // Construct SSE stream let mut reader = BufReader::new(write_rx); let mut line = String::new(); let response = match reader.read_line(&mut line).await { @@ -313,15 +393,35 @@ pub async fn process_incoming_message( match state.session_store.get(&session_id).await { Some(runtime) => { let runtime = runtime.lock().await.to_owned(); - - create_sse_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - ) - .await + // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport + // it should be processed by the same transport , therefore no need to call create_sse_stream + let Ok(is_result) = is_result(payload) else { + return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); + }; + + if is_result { + match runtime + .consume_payload_string(DEFAULT_STREAM_ID, payload) + .await + { + Ok(()) => Ok((StatusCode::ACCEPTED, Json(())).into_response()), + Err(err) => Ok(( + StatusCode::BAD_REQUEST, + Json(SdkError::internal_error().with_message(err.to_string().as_ref())), + ) + .into_response()), + } + } else { + create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + None, + ) + .await + } } None => { let error = SdkError::session_not_found(); @@ -330,6 +430,10 @@ pub async fn process_incoming_message( } } +pub fn is_empty_sse_message(sse_payload: &str) -> bool { + sse_payload.is_empty() || sse_payload.trim() == ":" +} + pub async fn delete_session( session_id: SessionId, state: Arc, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index e1c00f8..27a16b2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,3 +1,4 @@ +use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ hyper_servers::{ @@ -90,20 +91,24 @@ pub async fn handle_sse( let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); // create a transport for sending/receiving messages - let transport = SseTransport::new( + let Ok(transport) = SseTransport::new( read_rx, write_tx, read_tx, Arc::clone(&state.transport_options), - ) - .unwrap(); + ) else { + return Err(TransportServerError::TransportError( + "Failed to create SSE transport".to_string(), + )); + }; + let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let server: Arc = Arc::new(server_runtime::create_server_instance( + let server: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); state .session_store @@ -115,7 +120,12 @@ pub async fn handle_sse( // Start the server tokio::spawn(async move { match server - .start_stream(transport, DEFAULT_STREAM_ID, state.ping_interval, None) + .start_stream( + Arc::new(transport), + DEFAULT_STREAM_ID, + state.ping_interval, + None, + ) .await { Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 83cc372..67f8679 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,4 +1,4 @@ -use super::hyper_utils::{start_new_session, MCP_SESSION_ID_HEADER}; +use super::hyper_utils::start_new_session; use crate::schema::schema_utils::SdkError; use crate::{ error::McpSdkError, @@ -14,6 +14,7 @@ use crate::{ }, utils::valid_initialize_method, }; +use axum::routing::get; use axum::{ extract::{Query, State}, middleware, @@ -22,11 +23,9 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::SessionId; +use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use std::{collections::HashMap, sync::Arc}; -use axum::routing::get; - pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { Router::new() .route(streamable_http_endpoint, get(handle_streamable_http_get)) @@ -61,9 +60,14 @@ pub async fn handle_streamable_http_get( .and_then(|value| value.to_str().ok()) .map(|s| s.to_string()); + let last_event_id: Option = headers + .get(MCP_LAST_EVENT_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + match session_id { Some(session_id) => { - let res = create_standalone_stream(session_id, state).await?; + let res = create_standalone_stream(session_id, last_event_id, state).await?; Ok(res.into_response()) } None => { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index f093da3..71bccee 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,6 +1,8 @@ use crate::{ - error::SdkResult, mcp_server::hyper_runtime::HyperRuntime, - mcp_traits::mcp_handler::McpServerHandler, + error::SdkResult, + id_generator::{FastIdGenerator, UuidGenerator}, + mcp_server::hyper_runtime::HyperRuntime, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; #[cfg(feature = "ssl")] use axum_server::tls_rustls::RustlsConfig; @@ -17,11 +19,11 @@ use super::{ app_state::AppState, error::{TransportServerError, TransportServerResult}, routes::app_routes, - IdGenerator, InMemorySessionStore, UuidGenerator, + InMemorySessionStore, }; use crate::schema::InitializeResult; use axum::Router; -use rust_mcp_transport::TransportOptions; +use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); @@ -43,7 +45,7 @@ pub struct HyperServerOptions { pub port: u16, /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, + pub session_id_generator: Option>>, /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, @@ -51,6 +53,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -223,6 +229,7 @@ impl Default for HyperServerOptions { allowed_hosts: None, allowed_origins: None, dns_rebinding_protection: false, + event_store: None, } } } @@ -258,6 +265,7 @@ impl HyperServer { .session_id_generator .take() .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)), + stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))), server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, @@ -268,6 +276,7 @@ impl HyperServer { allowed_hosts: server_options.allowed_hosts.take(), allowed_origins: server_options.allowed_origins.take(), dns_rebinding_protection: server_options.dns_rebinding_protection, + event_store: server_options.event_store.as_ref().map(Arc::clone), }); let app = app_routes(Arc::clone(&state), &server_options); Self { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs index 95b2158..4384b1a 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs @@ -5,7 +5,6 @@ use async_trait::async_trait; pub use in_memory::*; use rust_mcp_transport::SessionId; use tokio::sync::Mutex; -use uuid::Uuid; use crate::mcp_server::ServerRuntime; @@ -46,26 +45,3 @@ pub trait SessionStore: Send + Sync { async fn has(&self, session: &SessionId) -> bool; } - -/// Trait for generating session identifiers -/// -/// Implementors must be Send and Sync to support concurrent access. -pub trait IdGenerator: Send + Sync { - fn generate(&self) -> SessionId; -} - -/// Struct implementing the IdGenerator trait using UUID v4 -/// -/// This is a simple wrapper around the uuid crate's Uuid::new_v4 function -/// to generate unique session identifiers. -pub struct UuidGenerator {} - -impl IdGenerator for UuidGenerator { - /// Generates a new UUID v4-based session identifier - /// - /// # Returns - /// * `SessionId` - A new UUID-based session identifier as a String - fn generate(&self) -> SessionId { - Uuid::new_v4().to_string() - } -} diff --git a/crates/rust-mcp-sdk/src/id_generator.rs b/crates/rust-mcp-sdk/src/id_generator.rs new file mode 100644 index 0000000..54f0e72 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator.rs @@ -0,0 +1,5 @@ +mod fast_id_generator; +mod uuid_generator; +pub use crate::mcp_traits::IdGenerator; +pub use fast_id_generator::*; +pub use uuid_generator::*; diff --git a/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs new file mode 100644 index 0000000..fc2e976 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs @@ -0,0 +1,53 @@ +use crate::mcp_traits::IdGenerator; +use base64::Engine; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// An [`IdGenerator`] implementation optimized for lightweight, locally-scoped identifiers. +/// +/// This generator produces short, incrementing identifiers that are Base64-encoded. +/// This makes it well-suited for cases such as `StreamId` generation, where: +/// - IDs only need to be unique within a single process or session +/// - Predictability is acceptable +/// - Shorter, more human-readable identifiers are desirable +/// +pub struct FastIdGenerator { + counter: AtomicU64, + ///Optional prefix for readability + prefix: &'static str, +} + +impl FastIdGenerator { + /// Creates a new ID generator with an optional prefix. + /// + /// # Arguments + /// * `prefix` - A static string to prepend to IDs (e.g., "sid_"). + pub fn new(prefix: Option<&'static str>) -> Self { + FastIdGenerator { + counter: AtomicU64::new(0), + prefix: prefix.unwrap_or_default(), + } + } +} + +impl IdGenerator for FastIdGenerator +where + T: From, +{ + /// Generates a new session ID as a short Base64-encoded string. + /// + /// Increments an internal counter atomically and encodes it in Base64 URL-safe format. + /// The resulting ID is prefixed (if provided) and typically 8–12 characters long. + /// + /// # Returns + /// * `SessionId` - A short, unique session ID (e.g., "sid_BBBB" or "BBBB"). + fn generate(&self) -> T { + let id = self.counter.fetch_add(1, Ordering::Relaxed); + let bytes = id.to_le_bytes(); + let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + if self.prefix.is_empty() { + T::from(encoded) + } else { + T::from(format!("{}{}", self.prefix, encoded)) + } + } +} diff --git a/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs new file mode 100644 index 0000000..2f0dc21 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs @@ -0,0 +1,18 @@ +use crate::mcp_traits::IdGenerator; +use uuid::Uuid; + +/// An [`IdGenerator`] implementation that uses UUID v4 to create unique identifiers. +/// +/// This generator produces random UUIDs (version 4), which are highly unlikely +/// to collide and difficult to predict. It is therefore well-suited for +/// generating identifiers such as `SessionId` or other values where uniqueness is important. +pub struct UuidGenerator; + +impl IdGenerator for UuidGenerator +where + T: From, +{ + fn generate(&self) -> T { + T::from(Uuid::new_v4().to_string()) + } +} diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 1ea23df..a33f889 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -21,7 +21,7 @@ pub mod mcp_client { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) for an example. + //! Refer to [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) for an example. //! //! //! - **client_runtime_core**: If you need more control over MCP messages, consider using @@ -30,7 +30,7 @@ pub mod mcp_client { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) for an example. + //! Refer to [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) for an example. pub use super::mcp_handlers::mcp_client_handler::ClientHandler; pub use super::mcp_handlers::mcp_client_handler_core::ClientHandlerCore; pub use super::mcp_runtimes::client_runtime::mcp_client_runtime as client_runtime; @@ -53,7 +53,7 @@ pub mod mcp_server { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) for an example. //! //! //! - **server_runtime_core**: If you need more control over MCP messages, consider using @@ -62,7 +62,7 @@ pub mod mcp_server { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/hello-world-mcp-server-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) for an example. pub use super::mcp_handlers::mcp_server_handler::ServerHandler; pub use super::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; @@ -93,4 +93,5 @@ pub mod macros { pub use rust_mcp_macros::*; } +pub mod id_generator; pub mod schema; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs index f8ee1a0..c6fb208 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs @@ -148,7 +148,7 @@ pub trait ClientHandler: Send + Sync + 'static { //********************// async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Ok(()) diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs index 3bbe5c9..a0afdf1 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs @@ -38,7 +38,7 @@ pub trait ClientHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP server. async fn handle_error( &self, - error: RpcError, + error: &RpcError, runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError>; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 5b0fdc0..9b9577e 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -1,6 +1,7 @@ use crate::schema::{schema_utils::CallToolError, *}; use async_trait::async_trait; use serde_json::Value; +use std::sync::Arc; use crate::{mcp_traits::mcp_server::McpServer, utils::enforce_compatible_protocol_version}; @@ -15,7 +16,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, runtime: &dyn McpServer) {} + async fn on_initialized(&self, runtime: Arc) {} /// Handles the InitializeRequest from a client. /// @@ -29,7 +30,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialize_request( &self, initialize_request: InitializeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let mut server_info = runtime.server_info().to_owned(); // Provide compatibility for clients using older MCP protocol versions. @@ -51,6 +52,7 @@ pub trait ServerHandler: Send + Sync + 'static { runtime .set_client_details(initialize_request.params.clone()) + .await .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; Ok(server_info) @@ -64,7 +66,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_ping_request( &self, _: PingRequest, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result { Ok(Result::default()) } @@ -76,7 +78,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resources_request( &self, request: ListResourcesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -92,7 +94,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resource_templates_request( &self, request: ListResourceTemplatesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -108,7 +110,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_read_resource_request( &self, request: ReadResourceRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -124,7 +126,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_subscribe_request( &self, request: SubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -140,7 +142,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_unsubscribe_request( &self, request: UnsubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -156,7 +158,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_prompts_request( &self, request: ListPromptsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -172,7 +174,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_get_prompt_request( &self, request: GetPromptRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -188,7 +190,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -204,7 +206,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -219,7 +221,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_set_level_request( &self, request: SetLevelRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -235,7 +237,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_complete_request( &self, request: CompleteRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -251,7 +253,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_custom_request( &self, request: Value, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())) @@ -264,7 +266,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialized_notification( &self, notification: InitializedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -274,7 +276,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_cancelled_notification( &self, notification: CancelledNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -284,7 +286,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_progress_notification( &self, notification: ProgressNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -294,7 +296,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_roots_list_changed_notification( &self, notification: RootsListChangedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -318,19 +320,9 @@ pub trait ServerHandler: Send + Sync + 'static { /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. async fn handle_error( &self, - error: RpcError, - runtime: &dyn McpServer, + error: &RpcError, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } - - /// Called when the server has successfully started. - /// - /// Sends a "Server started successfully" message to stderr. - /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index fffe2fc..9275da7 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -1,8 +1,8 @@ +use crate::mcp_traits::mcp_server::McpServer; use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; - -use crate::mcp_traits::mcp_server::McpServer; +use std::sync::Arc; /// Defines the `ServerHandlerCore` trait for handling Model Context Protocol (MCP) server operations. /// Unlike `ServerHandler`, this trait offers no default implementations, providing full control over MCP message handling @@ -14,7 +14,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, _runtime: &dyn McpServer) {} + async fn on_initialized(&self, _runtime: Arc) {} /// Asynchronously handles an incoming request from the client. /// @@ -26,7 +26,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; /// Asynchronously handles an incoming notification from the client. @@ -36,7 +36,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_notification( &self, notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; /// Asynchronously handles an error received from the client. @@ -45,12 +45,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// - `error` – The error data received from the MCP client. async fn handle_error( &self, - error: RpcError, - runtime: &dyn McpServer, + error: &RpcError, + runtime: Arc, ) -> std::result::Result<(), RpcError>; - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs index 3bd2735..3edb344 100644 --- a/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs +++ b/crates/rust-mcp-sdk/src/mcp_macros/tool_box.rs @@ -57,15 +57,6 @@ macro_rules! tool_box { )* ] } - - #[deprecated(since = "0.2.0", note = "Use `tools()` instead.")] - pub fn get_tools() -> Vec { - vec![ - $( - $tool::tool(), - )* - ] - } } @@ -76,8 +67,20 @@ macro_rules! tool_box { /// Attempts to convert a tool request into the appropriate tool variant fn try_from(value: rust_mcp_sdk::schema::CallToolRequestParams) -> Result { - let v = serde_json::to_value(value.arguments.unwrap()) - .map_err(rust_mcp_sdk::schema::schema_utils::CallToolError::new)?; + let arguments = value + .arguments + .ok_or(rust_mcp_sdk::schema::schema_utils::CallToolError::invalid_arguments( + &value.name, + Some("Missing 'arguments' field in the request".to_string()) + ))?; + + let v = serde_json::to_value(arguments).map_err(|err| { + rust_mcp_sdk::schema::schema_utils::CallToolError::invalid_arguments( + &value.name, + Some(format!("{err}")), + ) + })?; + match value.name { $( name if name == $tool::tool_name().as_str() => { diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 8d113c3..2093dc3 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,70 +1,120 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; - -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, +use crate::error::{McpSdkError, SdkResult}; +use crate::id_generator::FastIdGenerator; +use crate::mcp_traits::mcp_client::McpClient; +use crate::mcp_traits::mcp_handler::McpClientHandler; +use crate::mcp_traits::IdGenerator; +use crate::utils::ensure_server_protocole_compatibility; +use crate::{ + mcp_traits::{RequestIdGen, RequestIdGenNumeric}, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, + ServerMessage, ServerMessages, + }, + InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, + RequestId, RpcError, ServerResult, }, - InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, - RpcError, ServerResult, }; use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; -use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::sync::{Arc, RwLock}; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions}; +use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::sync::Mutex; +use tokio::sync::{watch, Mutex}; -use crate::error::{McpSdkError, SdkResult}; -use crate::mcp_traits::mcp_client::McpClient; -use crate::mcp_traits::mcp_handler::McpClientHandler; -use crate::utils::ensure_server_protocole_compatibility; +pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; + +// Define a type alias for the TransportDispatcher trait object +type TransportDispatcherType = dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, +>; +type TransportType = Arc; pub struct ClientRuntime { - // The transport interface for handling messages between client and server - transport: Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + // A thread-safe map storing transport types + transport_map: tokio::sync::RwLock>, // The handler for processing MCP messages handler: Box, - // // Information about the server + // Information about the server client_details: InitializeRequestParams, - // Details about the connected server - server_details: Arc>>, handlers: Mutex>>>, + // Generator for unique request IDs + request_id_gen: Box, + // Generator for stream IDs + stream_id_gen: FastIdGenerator, + #[cfg(feature = "streamable-http")] + // Optional configuration for streamable transport + transport_options: Option, + // Flag indicating whether the client has been shut down + is_shut_down: Mutex, + // Session ID + session_id: tokio::sync::RwLock>, + // Details about the connected server + server_details_tx: watch::Sender>, + server_details_rx: watch::Receiver>, } impl ClientRuntime { pub(crate) fn new( client_details: InitializeRequestParams, - transport: impl Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, + transport: TransportType, handler: Box, ) -> Self { + let mut map: HashMap = HashMap::new(); + map.insert(DEFAULT_STREAM_ID.to_string(), transport); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); Self { - transport: Arc::new(transport), + transport_map: tokio::sync::RwLock::new(map), handler, client_details, - server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + #[cfg(feature = "streamable-http")] + transport_options: None, + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, } } - async fn initialize_request(&self) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + pub(crate) fn new_instance( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: Box, + ) -> Self { + let map: HashMap = HashMap::new(); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); + Self { + transport_map: tokio::sync::RwLock::new(map), + handler, + client_details, + handlers: Mutex::new(vec![]), + transport_options: Some(transport_options), + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, + } + } + + async fn initialize_request(self: Arc) -> SdkResult<()> { let request = InitializeRequest::new(self.client_details.clone()); let result: ServerResult = self.request(request.into(), None).await?.try_into()?; @@ -73,9 +123,15 @@ impl ClientRuntime { &self.client_details.protocol_version, &initialize_result.protocol_version, )?; - // store server details self.set_server_details(initialize_result)?; + + #[cfg(feature = "streamable-http")] + // try to create a sse stream for server initiated messages , if supported by the server + if let Err(error) = self.clone().create_sse_stream().await { + tracing::warn!("{error}"); + } + // send a InitializedNotification to the server self.send_notification(InitializedNotification::new(None).into()) .await?; @@ -84,21 +140,14 @@ impl ClientRuntime { .with_message("Incorrect response to InitializeRequest!".into()) .into()); } + Ok(()) } pub(crate) async fn handle_message( &self, message: ServerMessage, - transport: &Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + transport: &TransportType, ) -> SdkResult> { let response = match message { ServerMessage::Request(jsonrpc_request) => { @@ -123,7 +172,19 @@ impl ClientRuntime { None } ServerMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ServerMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request: {:?}", + &jsonrpc_error.id + ); + } None } ServerMessage::Response(response) => { @@ -133,7 +194,7 @@ impl ClientRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } @@ -142,28 +203,26 @@ impl ClientRuntime { }; Ok(response) } -} -#[async_trait] -impl McpClient for ClientRuntime { - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch, - { - (self.transport.message_sender().clone()) as _ - } + async fn start_standalone(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; - async fn start(self: Arc) -> SdkResult<()> { //TODO: improve the flow - let mut stream = self.transport.start().await?; - let transport = self.transport.clone(); + let mut stream = transport.start().await?; + + let transport_clone = transport.clone(); let mut error_io_stream = transport.error_stream().write().await; let error_io_stream = error_io_stream.take(); let self_clone = Arc::clone(&self); let self_clone_err = Arc::clone(&self); + // task reading from the error stream let err_task = tokio::spawn(async move { let self_ref = &*self_clone_err; @@ -171,7 +230,7 @@ impl McpClient for ClientRuntime { let mut reader = BufReader::new(error_input).lines(); loop { tokio::select! { - should_break = self_ref.transport.is_shut_down() =>{ + should_break = transport_clone.is_shut_down() =>{ if should_break { break; } @@ -201,14 +260,10 @@ impl McpClient for ClientRuntime { Ok::<(), McpSdkError>(()) }); - let transport = self.transport.clone(); + let transport = transport.clone(); + // main task reading from mcp_message stream let main_task = tokio::spawn(async move { - let sender = self_clone.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; while let Some(mcp_messages) = stream.next().await { let self_ref = &*self_clone; @@ -219,7 +274,7 @@ impl McpClient for ClientRuntime { match result { Ok(result) => { if let Some(result) = result { - sender + transport .send_message(ClientMessages::Single(result), None) .await?; } @@ -240,7 +295,7 @@ impl McpClient for ClientRuntime { let results: Vec<_> = results.into_iter().flatten().collect(); if !results.is_empty() { - sender + transport .send_message(ClientMessages::Batch(results), None) .await?; } @@ -251,44 +306,349 @@ impl McpClient for ClientRuntime { }); // send initialize request to the MCP server - self.initialize_request().await?; + self.clone().initialize_request().await?; let mut lock = self.handlers.lock().await; lock.push(main_task); lock.push(err_task); + Ok(()) + } + pub(crate) async fn store_transport( + &self, + stream_id: &str, + transport: TransportType, + ) -> SdkResult<()> { + let mut transport_map = self.transport_map.write().await; + tracing::trace!("save transport for stream id : {}", stream_id); + transport_map.insert(stream_id.to_string(), transport); Ok(()) } - fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { - match self.server_details.write() { - Ok(mut details) => { - *details = Some(server_details); - Ok(()) + pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult { + let transport_map = self.transport_map.read().await; + transport_map.get(stream_id).cloned().ok_or_else(|| { + RpcError::internal_error() + .with_message(format!("Transport for key {stream_id} not found")) + .into() + }) + } + + #[cfg(feature = "streamable-http")] + pub(crate) async fn new_transport( + &self, + session_id: Option, + standalone: bool, + ) -> SdkResult< + impl TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > { + let options = self + .transport_options + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + let transport = ClientStreamableTransport::new(options, session_id, standalone)?; + + Ok(transport) + } + + #[cfg(feature = "streamable-http")] + pub(crate) async fn create_sse_stream(self: Arc) -> SdkResult<()> { + let stream_id: StreamId = DEFAULT_STREAM_ID.into(); + let session_id = self.session_id.read().await.clone(); + let transport: Arc< + dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > = Arc::new(self.new_transport(session_id, true).await?); + let mut stream = transport.start().await?; + self.store_transport(&stream_id, transport.clone()).await?; + + let self_clone = Arc::clone(&self); + + let main_task = tokio::spawn(async move { + loop { + if let Some(mcp_messages) = stream.next().await { + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + + if let Some(result) = result { + transport + .send_message(ClientMessages::Single(result), None) + .await?; + } + } + ServerMessages::Batch(server_messages) => { + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| { + self.handle_message(server_message, &transport) + }) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport + .send_message(ClientMessages::Batch(results), None) + .await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID) { + return Ok::<_, McpSdkError>(()); + } + } else { + // end of stream + return Ok::<_, McpSdkError>(()); + } + } + }); + + let mut lock = self_clone.handlers.lock().await; + lock.push(main_task); + + Ok(()) + } + + #[cfg(feature = "streamable-http")] + pub(crate) async fn start_stream( + &self, + messages: ClientMessages, + timeout: Option, + ) -> SdkResult> { + use futures::stream::{AbortHandle, Abortable}; + let stream_id: StreamId = self.stream_id_gen.generate(); + let session_id = self.session_id.read().await.clone(); + let no_session_id = session_id.is_none(); + + let has_request = match &messages { + ClientMessages::Single(client_message) => client_message.is_request(), + ClientMessages::Batch(client_messages) => { + client_messages.iter().any(|m| m.is_request()) + } + }; + + let transport = Arc::new(self.new_transport(session_id, false).await?); + + let mut stream = transport.start().await?; + + self.store_transport(&stream_id, transport).await?; + + let transport = self.transport_by_stream(&stream_id).await?; //TODO: remove + + let send_task = async { + let result = transport.send_message(messages, timeout).await?; + + if no_session_id { + if let Some(request_id) = transport.session_id().await.clone() { + let mut guard = self.session_id.write().await; + *guard = Some(request_id) + } } - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - Err(_) => Err(RpcError::internal_error() - .with_message("Internal Error: Failed to acquire write lock.".to_string()) - .into()), + + Ok::<_, McpSdkError>(result) + }; + + if !has_request { + return send_task.await; } + + let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair(); + + let receive_task = async { + loop { + tokio::select! { + Some(mcp_messages) = stream.next() =>{ + + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + if let Some(result) = result { + transport.send_message(ClientMessages::Single(result), None).await?; + } + } + ServerMessages::Batch(server_messages) => { + + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| self.handle_message(server_message, &transport)) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport.send_message(ClientMessages::Batch(results), None).await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID){ + return Ok::<_, McpSdkError>(()); + } + } + } + } + }; + + let receive_task = Abortable::new(receive_task, abort_recv_reg); + + // Pin the tasks to ensure they are not moved + tokio::pin!(send_task); + tokio::pin!(receive_task); + + // Run both tasks with cancellation logic + let (send_res, _) = tokio::select! { + res = &mut send_task => { + // cancel the receive_task task, to cover the case where send_task returns with error + abort_recv_handle.abort(); + (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation) + } + res = &mut receive_task => { + (send_task.await, res) + } + }; + send_res + } +} + +#[async_trait] +impl McpClient for ClientRuntime { + async fn send( + &self, + message: MessageFromClient, + request_id: Option, + request_timeout: Option, + ) -> SdkResult> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + + let response = self + .start_stream(ClientMessages::Single(mcp_message), request_timeout) + .await?; + return response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()); + } + } + + let transport_map = self.transport_map.read().await; + + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + let response = transport + .send_message(ClientMessages::Single(mcp_message), request_timeout) + .await?; + response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()) + } + + async fn send_batch( + &self, + messages: Vec, + timeout: Option, + ) -> SdkResult>> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let result = self + .start_stream(ClientMessages::Batch(messages), timeout) + .await?; + // let response = self.start_stream(&stream_id, request_id, message).await?; + return result + .map(|r| r.as_batch()) + .transpose() + .map_err(|err| err.into()); + } + } + + let transport_map = self.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + transport + .send_batch(messages, timeout) + .await + .map_err(|err| err.into()) + } + + async fn start(self: Arc) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + self.initialize_request().await?; + return Ok(()); + } + } + + self.start_standalone().await + } + + fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { + self.server_details_tx + .send(Some(server_details)) + .map_err(|_| { + RpcError::internal_error() + .with_message("Failed to set server details".to_string()) + .into() + }) } + fn client_info(&self) -> &InitializeRequestParams { &self.client_details } + fn server_info(&self) -> Option { - if let Ok(details) = self.server_details.read() { - details.clone() - } else { - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - None - } + self.server_details_rx.borrow().clone() } async fn is_shut_down(&self) -> bool { - self.transport.is_shut_down().await + let result = self.is_shut_down.lock().await; + *result } + async fn shut_down(&self) -> SdkResult<()> { - self.transport.shut_down().await?; + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + let mut transport_map = self.transport_map.write().await; + let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); + drop(transport_map); + for transport in transports { + let _ = transport.shut_down().await; + } // wait for tasks let mut tasks_lock = self.handlers.lock().await; @@ -297,4 +657,18 @@ impl McpClient for ClientRuntime { Ok(()) } + + async fn terminate_session(&self) { + #[cfg(feature = "streamable-http")] + { + if let Some(transport_options) = self.transport_options.as_ref() { + let session_id = self.session_id.read().await.clone(); + transport_options + .terminate_session(session_id.as_ref()) + .await; + let _ = self.shut_down().await; + } + } + let _ = self.shut_down().await; + } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 9ccd4d9..43a7079 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -8,7 +8,10 @@ use crate::schema::{ InitializeRequestParams, RpcError, ServerNotification, ServerRequest, }; use async_trait::async_trait; -use rust_mcp_transport::Transport; + +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; use crate::{ error::SdkResult, mcp_client::ClientHandler, mcp_traits::mcp_handler::McpClientHandler, @@ -37,10 +40,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -51,7 +54,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandler, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientInternalHandler::new(Box::new(handler))), )) } @@ -113,7 +129,7 @@ impl McpClientHandler for ClientInternalHandler> { /// Handles errors received from the server by passing the request to self.handler async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 3bdc318..884de9d 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -1,5 +1,4 @@ -use std::sync::Arc; - +use super::ClientRuntime; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, MessageFromClient, NotificationFromServer, @@ -7,17 +6,16 @@ use crate::schema::{ }, InitializeRequestParams, RpcError, }; -use async_trait::async_trait; - -use rust_mcp_transport::Transport; - use crate::{ error::SdkResult, mcp_handlers::mcp_client_handler_core::ClientHandlerCore, mcp_traits::{mcp_client::McpClient, mcp_handler::McpClientHandler}, }; - -use super::ClientRuntime; +use async_trait::async_trait; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; +use std::sync::Arc; /// Creates a new MCP client runtime with the specified configuration. /// @@ -39,10 +37,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -53,7 +51,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientCoreInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandlerCore, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientCoreInternalHandler::new(Box::new(handler))), )) } @@ -83,7 +94,7 @@ impl McpClientHandler for ClientCoreInternalHandler> async fn handle_error( &self, - jsonrpc_error: RpcError, + jsonrpc_error: &RpcError, runtime: &dyn McpClient, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 28cdd8c..5502cee 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -1,5 +1,9 @@ pub mod mcp_server_runtime; pub mod mcp_server_runtime_core; +use crate::error::SdkResult; +use crate::mcp_traits::mcp_handler::McpServerHandler; +use crate::mcp_traits::mcp_server::McpServer; +use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric}; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, @@ -7,25 +11,23 @@ use crate::schema::{ }, InitializeRequestParams, InitializeResult, RequestId, RpcError, }; - +use crate::utils::AbortTaskOnDrop; use async_trait::async_trait; use futures::future::try_join_all; use futures::{StreamExt, TryFutureExt}; - +#[cfg(feature = "hyper-server")] +use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; - use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::panic; +use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; -use tokio::sync::oneshot; -use crate::error::SdkResult; -use crate::mcp_traits::mcp_handler::McpServerHandler; -use crate::mcp_traits::mcp_server::McpServer; -#[cfg(feature = "hyper-server")] -use rust_mcp_transport::SessionId; +use tokio::sync::{mpsc, oneshot, watch}; + pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; +const TASK_CHANNEL_CAPACITY: usize = 500; // Define a type alias for the TransportDispatcher trait object type TransportType = Arc< @@ -44,26 +46,34 @@ pub struct ServerRuntime { handler: Arc, // Information about the server server_details: Arc, - // Details about the connected client - client_details: Arc>>, #[cfg(feature = "hyper-server")] session_id: Option, - transport_map: tokio::sync::RwLock>, + transport_map: tokio::sync::RwLock>, //TODO: remove the transport_map, we do not need a hashmap for it + request_id_gen: Box, + client_details_tx: watch::Sender>, + client_details_rx: watch::Receiver>, } #[async_trait] impl McpServer for ServerRuntime { /// Set the client details, storing them in client_details - fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { - match self.client_details.write() { - Ok(mut details) => { - *details = Some(client_details); - Ok(()) + async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { + self.client_details_tx + .send(Some(client_details)) + .map_err(|_| { + RpcError::internal_error() + .with_message("Failed to set client details".to_string()) + .into() + }) + } + + async fn wait_for_initialization(&self) { + loop { + if self.client_details_rx.borrow().is_some() { + return; } - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - Err(_) => Err(RpcError::internal_error() - .with_message("Internal Error: Failed to acquire write lock.".to_string()) - .into()), + let mut rx = self.client_details_rx.clone(); + rx.changed().await.ok(); } } @@ -72,18 +82,26 @@ impl McpServer for ServerRuntime { message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult> { + ) -> SdkResult> { let transport_map = self.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() .with_message("transport stream does not exists or is closed!".to_string()), )?; - let mcp_message = ServerMessage::from_message(message, request_id)?; - transport + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + + let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?; + + let response = transport .send_message(ServerMessages::Single(mcp_message), request_timeout) - .map_err(|err| err.into()) - .await + .await? + .map(|res| res.as_single()) + .transpose()?; + + Ok(response) } async fn send_batch( @@ -111,17 +129,13 @@ impl McpServer for ServerRuntime { /// Returns the client information if available, after successful initialization , otherwise returns None fn client_info(&self) -> Option { - if let Ok(details) = self.client_details.read() { - details.clone() - } else { - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - None - } + self.client_details_rx.borrow().clone() } /// Main runtime loop, processes incoming messages and handles requests - async fn start(&self) -> SdkResult<()> { - let transport_map = self.transport_map.read().await; + async fn start(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -130,45 +144,88 @@ impl McpServer for ServerRuntime { let mut stream = transport.start().await?; - self.handler.on_server_started(self).await; + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); // Process incoming messages from the client while let Some(mcp_messages) = stream.next().await { match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, transport).await; - - match result { - Ok(result) => { - if let Some(result) = result { - transport - .send_message(ServerMessages::Single(result), None) - .await?; + let transport = transport.clone(); + let self = self.clone(); + let tx = tx.clone(); + + // Handle incoming messages in a separate task to avoid blocking the stream. + tokio::spawn(async move { + let result = self.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + // Send result to the main loop + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send result to channel: {}", error); } - Err(error) => { - tracing::error!("Error handling message : {}", error) - } - } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); + let transport = transport.clone(); + let self = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport + .send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => Err(error), + }; - if !results.is_empty() { - transport - .send_message(ServerMessages::Batch(results), None) - .await?; - } + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } + } + + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors } + return Ok(()); } @@ -187,6 +244,11 @@ impl McpServer for ServerRuntime { } Ok(()) } + + #[cfg(feature = "hyper-server")] + fn session_id(&self) -> Option { + self.session_id.to_owned() + } } impl ServerRuntime { @@ -208,7 +270,7 @@ impl ServerRuntime { } pub(crate) async fn handle_message( - &self, + self: &Arc, message: ClientMessage, transport: &Arc< dyn TransportDispatcher< @@ -225,7 +287,7 @@ impl ServerRuntime { ClientMessage::Request(client_jsonrpc_request) => { let result = self .handler - .handle_request(client_jsonrpc_request.request, self) + .handle_request(client_jsonrpc_request.request, self.clone()) .await; // create a response to send back to the client let response: MessageFromServer = match result { @@ -247,15 +309,26 @@ impl ServerRuntime { } ClientMessage::Notification(client_jsonrpc_notification) => { self.handler - .handle_notification(client_jsonrpc_notification.notification, self) + .handle_notification(client_jsonrpc_notification.notification, self.clone()) .await?; None } ClientMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + self.handler + .handle_error(&jsonrpc_error.error, self.clone()) + .await?; + if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { + tx_response + .send(ClientMessage::Error(jsonrpc_error)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received an error response with no corresponding request {:?}", + &jsonrpc_error.id + ); + } None } - // The response is the result of a request, it is processed at the transport level. ClientMessage::Response(response) => { if let Some(tx_response) = transport.pending_request_tx(&response.id).await { tx_response @@ -263,7 +336,7 @@ impl ServerRuntime { .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; } else { tracing::warn!( - "Received response or error without a matching request: {:?}", + "Received a response with no corresponding request: {:?}", &response.id ); } @@ -286,41 +359,29 @@ impl ServerRuntime { >, >, ) -> SdkResult<()> { + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } let mut transport_map = self.transport_map.write().await; tracing::trace!("save transport for stream id : {}", stream_id); transport_map.insert(stream_id.to_string(), transport); Ok(()) } + //TODO: re-visit and simplify unnecessary hashmap pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { - let mut transport_map = self.transport_map.write().await; + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } + let transport_map = self.transport_map.read().await; tracing::trace!("removing transport for stream id : {}", stream_id); - transport_map.remove(stream_id); + if let Some(transport) = transport_map.get(stream_id) { + transport.shut_down().await?; + } + // transport_map.remove(stream_id); Ok(()) } - pub(crate) async fn transport_by_stream( - &self, - stream_id: &str, - ) -> SdkResult< - Arc< - dyn TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, - >, - >, - > { - let transport_map = self.transport_map.read().await; - transport_map.get(stream_id).cloned().ok_or_else(|| { - RpcError::internal_error() - .with_message(format!("Transport for key {stream_id} not found")) - .into() - }) - } - pub(crate) async fn shutdown(&self) { let mut transport_map = self.transport_map.write().await; let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); @@ -332,17 +393,24 @@ impl ServerRuntime { pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool { let transport_map = self.transport_map.read().await; - transport_map.contains_key(stream_id) + let live_transport = if let Some(t) = transport_map.get(stream_id) { + !t.is_shut_down().await + } else { + false + }; + live_transport } pub(crate) async fn start_stream( self: Arc, - transport: impl TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, + transport: Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, >, stream_id: &str, ping_interval: Duration, @@ -350,52 +418,122 @@ impl ServerRuntime { ) -> SdkResult<()> { let mut stream = transport.start().await?; - self.store_transport(stream_id, Arc::new(transport)).await?; + if stream_id == DEFAULT_STREAM_ID { + self.store_transport(stream_id, transport.clone()).await?; + } - let transport = self.transport_by_stream(stream_id).await?; + let self_clone = self.clone(); let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); - let _ = transport.keep_alive(ping_interval, disconnect_tx).await; + let abort_alive_task = transport + .keep_alive(ping_interval, disconnect_tx) + .await? + .abort_handle(); + + // ensure keep_alive task will be aborted + let _abort_guard = AbortTaskOnDrop { + handle: abort_alive_task, + }; // in case there is a payload, we consume it by transport to get processed + // payload would be message payload coming from the client if let Some(payload) = payload { - transport.consume_string_payload(&payload).await?; + if let Err(err) = transport.consume_string_payload(&payload).await { + let _ = self.remove_transport(stream_id).await; + return Err(err.into()); + } } + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + loop { tokio::select! { Some(mcp_messages) = stream.next() =>{ match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, &transport).await?; - if let Some(result) = result { - transport.send_message(ServerMessages::Single(result), None).await?; - } + let transport = transport.clone(); + let self_clone = self.clone(); + let tx = tx.clone(); + tokio::spawn(async move { + + let result = self_clone.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, &transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); - - - if !results.is_empty() { - transport.send_message(ServerMessages::Batch(results), None).await?; - } + let transport = transport.clone(); + let self_clone = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self_clone.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport.send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + }else { + Ok(None) + } + }, + Err(error) => Err(error), + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } + // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } return Ok(()); } } _ = &mut disconnect_rx => { + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } self.remove_transport(stream_id).await?; // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); @@ -405,24 +543,23 @@ impl ServerRuntime { } } - #[cfg(feature = "hyper-server")] - pub(crate) async fn session_id(&self) -> Option { - self.session_id.to_owned() - } - #[cfg(feature = "hyper-server")] pub(crate) fn new_instance( server_details: Arc, handler: Arc, session_id: SessionId, - ) -> Self { - Self { + ) -> Arc { + let (client_details_tx, client_details_rx) = + watch::channel::>(None); + Arc::new(Self { server_details, - client_details: Arc::new(RwLock::new(None)), handler, session_id: Some(session_id), transport_map: tokio::sync::RwLock::new(HashMap::new()), - } + client_details_tx, + client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + }) } pub(crate) fn new( @@ -435,16 +572,20 @@ impl ServerRuntime { ServerMessage, >, handler: Arc, - ) -> Self { + ) -> Arc { let mut map: HashMap = HashMap::new(); map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); - Self { + let (client_details_tx, client_details_rx) = + watch::channel::>(None); + Arc::new(Self { server_details: Arc::new(server_details), - client_details: Arc::new(RwLock::new(None)), handler, #[cfg(feature = "hyper-server")] session_id: None, transport_map: tokio::sync::RwLock::new(map), - } + client_details_tx, + client_details_rx, + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + }) } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index 26f37e1..62fd31f 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -38,7 +38,7 @@ use crate::{ /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< @@ -49,7 +49,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandler, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -62,7 +62,7 @@ pub(crate) fn create_server_instance( server_details: Arc, handler: Arc, session_id: SessionId, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new_instance(server_details, handler, session_id) } @@ -80,7 +80,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { match client_jsonrpc_request { schema_utils::RequestFromClient::ClientRequest(client_request) => { @@ -177,8 +177,8 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, - jsonrpc_error: RpcError, - runtime: &dyn McpServer, + jsonrpc_error: &RpcError, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -187,7 +187,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { match client_jsonrpc_notification { schema_utils::NotificationFromClient::ClientNotification(client_notification) => { @@ -199,7 +199,10 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } ClientNotification::InitializedNotification(initialized_notification) => { self.handler - .handle_initialized_notification(initialized_notification, runtime) + .handle_initialized_notification( + initialized_notification, + runtime.clone(), + ) .await?; self.handler.on_initialized(runtime).await; } @@ -226,8 +229,4 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } Ok(()) } - - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index 27f04df..110b20b 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -32,7 +32,7 @@ use std::sync::Arc; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< @@ -43,7 +43,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandlerCore, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -66,7 +66,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // store the client details if the request is a client initialization request if let schema_utils::RequestFromClient::ClientRequest(ClientRequest::InitializeRequest( @@ -76,6 +76,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> // keep a copy of the InitializeRequestParams which includes client_info and capabilities runtime .set_client_details(initialize_request.params.clone()) + .await .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; } @@ -86,8 +87,8 @@ impl McpServerHandler for RuntimeCoreInternalHandler> } async fn handle_error( &self, - jsonrpc_error: RpcError, - runtime: &dyn McpServer, + jsonrpc_error: &RpcError, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -95,11 +96,11 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { // Trigger the `on_initialized()` callback if an `initialized_notification` is received from the client. if client_jsonrpc_notification.is_initialized_notification() { - self.handler.on_initialized(runtime).await; + self.handler.on_initialized(runtime.clone()).await; } // handle notification @@ -108,7 +109,4 @@ impl McpServerHandler for RuntimeCoreInternalHandler> .await?; Ok(()) } - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 511731c..b66ba93 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -1,5 +1,10 @@ +pub(super) mod id_generator; #[cfg(feature = "client")] pub mod mcp_client; pub mod mcp_handler; #[cfg(feature = "server")] pub mod mcp_server; +mod request_id_gen; + +pub use id_generator::*; +pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs new file mode 100644 index 0000000..e7cb8d3 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs @@ -0,0 +1,12 @@ +/// Trait for generating unique identifiers. +/// +/// This trait is generic over the target ID type, allowing it to be used for +/// generating different kinds of identifiers such as `SessionId` or +/// transport-scoped `StreamId`. +/// +pub trait IdGenerator: Send + Sync +where + T: From, +{ + fn generate(&self) -> T; +} diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 8e72c26..5fe3fba 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -1,50 +1,35 @@ -use std::{sync::Arc, time::Duration}; - use crate::schema::{ schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, - NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages, + ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient, + ResultFromServer, ServerMessage, }, CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams, CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest, ListPromptsRequestParams, ListResourceTemplatesRequest, ListResourceTemplatesRequestParams, ListResourcesRequest, ListResourcesRequestParams, ListRootsRequest, ListToolsRequest, ListToolsRequestParams, - LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, + LoggingLevel, PingRequest, ReadResourceRequest, ReadResourceRequestParams, RequestId, RootsListChangedNotification, RootsListChangedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, }; use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; -use rust_mcp_transport::{McpDispatch, MessageDispatcher}; +use std::{sync::Arc, time::Duration}; #[async_trait] pub trait McpClient: Sync + Send { async fn start(self: Arc) -> SdkResult<()>; fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>; + async fn terminate_session(&self); + async fn shut_down(&self) -> SdkResult<()>; async fn is_shut_down(&self) -> bool; - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch; - fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> &InitializeRequestParams { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> Option { - self.server_info() - } - /// Checks whether the server has been initialized with client fn is_initialized(&self) -> bool { self.server_info().is_some() @@ -57,23 +42,12 @@ pub trait McpClient: Sync + Send { .map(|server_details| server_details.server_info) } - #[deprecated(since = "0.2.0", note = "Use `server_version()` instead.")] - fn get_server_version(&self) -> Option { - self.server_info() - .map(|server_details| server_details.server_info) - } - /// Returns the server's capabilities. /// After initialization has completed, this will be populated with the server's reported capabilities. fn server_capabilities(&self) -> Option { self.server_info().map(|item| item.capabilities) } - #[deprecated(since = "0.2.0", note = "Use `server_capabilities()` instead.")] - fn get_server_capabilities(&self) -> Option { - self.server_info().map(|item| item.capabilities) - } - /// Checks if the server has tools available. /// /// This function retrieves the server information and checks if the @@ -156,10 +130,6 @@ pub trait McpClient: Sync + Send { self.server_info() .map(|server_details| server_details.capabilities.logging.is_some()) } - #[deprecated(since = "0.2.0", note = "Use `instructions()` instead.")] - fn get_instructions(&self) -> Option { - self.server_info()?.instructions - } fn instructions(&self) -> Option { self.server_info()?.instructions @@ -175,27 +145,15 @@ pub trait McpClient: Sync + Send { request: RequestFromClient, timeout: Option, ) -> SdkResult { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let request_id = sender.next_request_id(); - - let mcp_message = - ClientMessage::from_message(MessageFromClient::from(request), Some(request_id))?; - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) + let response = self + .send(MessageFromClient::RequestFromClient(request), None, timeout) .await?; let server_message = response.ok_or_else(|| { RpcError::internal_error() - .with_message("An empty response was received from the server.".to_string()) + .with_message("An empty response was received from the client.".to_string()) })?; - let server_message = server_message.as_single()?; - if server_message.is_error() { return Err(server_message.as_error()?.error.into()); } @@ -205,67 +163,22 @@ pub trait McpClient: Sync + Send { async fn send( &self, - message: ClientMessage, - timeout: Option, - ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Single(message), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_single()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + message: MessageFromClient, + request_id: Option, + request_timeout: Option, + ) -> SdkResult>; async fn send_batch( &self, messages: Vec, timeout: Option, - ) -> SdkResult>> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Batch(messages), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_batch()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + ) -> SdkResult>>; /// Sends a notification. This is a one-way message that is not expected /// to return any response. The method asynchronously sends the notification using /// the transport layer and does not wait for any acknowledgement or result. async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let mcp_message = ClientMessage::from_message(MessageFromClient::from(notification), None)?; - - sender - .send_message(ClientMessages::Single(mcp_message), None) - .await?; + self.send(notification.into(), None, None).await?; Ok(()) } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index c86a623..cb37f2a 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -6,9 +6,9 @@ use crate::schema::schema_utils::{NotificationFromClient, RequestFromClient, Res #[cfg(feature = "client")] use crate::schema::schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}; -use crate::schema::RpcError; - use crate::error::SdkResult; +use crate::schema::RpcError; +use std::sync::Arc; #[cfg(feature = "client")] use super::mcp_client::McpClient; @@ -18,18 +18,20 @@ use super::mcp_server::McpServer; #[cfg(feature = "server")] #[async_trait] pub trait McpServerHandler: Send + Sync { - async fn on_server_started(&self, runtime: &dyn McpServer); async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpServer) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: Arc, + ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; } @@ -41,8 +43,11 @@ pub trait McpClientHandler: Send + Sync { server_jsonrpc_request: RequestFromServer, runtime: &dyn McpClient, ) -> std::result::Result; - async fn handle_error(&self, jsonrpc_error: RpcError, runtime: &dyn McpClient) - -> SdkResult<()>; + async fn handle_error( + &self, + jsonrpc_error: &RpcError, + runtime: &dyn McpClient, + ) -> SdkResult<()>; async fn handle_notification( &self, server_jsonrpc_notification: NotificationFromServer, diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index cf0f168..da087d1 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -1,48 +1,40 @@ -use std::time::Duration; - use crate::schema::{ schema_utils::{ - ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer, - RequestFromServer, ResultFromClient, ServerMessage, + ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, + ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, - GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult, - ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, - ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification, + ElicitRequest, ElicitRequestParams, ElicitRequestedSchema, ElicitResult, GetPromptRequest, + Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest, + ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams, + ListRootsResult, ListToolsRequest, LoggingMessageNotification, LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification, PromptListChangedNotificationParams, ReadResourceRequest, RequestId, ResourceListChangedNotification, ResourceListChangedNotificationParams, ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; -use async_trait::async_trait; - use crate::{error::SdkResult, utils::format_assertion_message}; +use async_trait::async_trait; +use rust_mcp_transport::SessionId; +use std::{sync::Arc, time::Duration}; //TODO: support options , such as enforceStrictCapabilities #[async_trait] pub trait McpServer: Sync + Send { - async fn start(&self) -> SdkResult<()>; - fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; + async fn start(self: Arc) -> SdkResult<()>; + async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; fn server_info(&self) -> &InitializeResult; fn client_info(&self) -> Option; - #[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")] - fn get_client_info(&self) -> Option { - self.client_info() - } - - #[deprecated(since = "0.2.0", note = "Use `server_info()` instead.")] - fn get_server_info(&self) -> &InitializeResult { - self.server_info() - } + async fn wait_for_initialization(&self); async fn send( &self, message: MessageFromServer, request_id: Option, request_timeout: Option, - ) -> SdkResult>; + ) -> SdkResult>; async fn send_batch( &self, @@ -67,6 +59,23 @@ pub trait McpServer: Sync + Send { &self.server_info().capabilities } + /// Sends an elicitation request to the client to prompt user input and returns the received response. + /// + /// The requested_schema argument allows servers to define the structure of the expected response using a restricted subset of JSON Schema. + /// To simplify client user experience, elicitation schemas are limited to flat objects with primitive properties only + async fn elicit_input( + &self, + message: String, + requested_schema: ElicitRequestedSchema, + ) -> SdkResult { + let request: ElicitRequest = ElicitRequest::new(ElicitRequestParams { + message, + requested_schema, + }); + let response = self.request(request.into(), None).await?; + ElicitResult::try_from(response).map_err(|err| err.into()) + } + /// Sends a request to the client and processes the response. /// /// This function sends a `RequestFromServer` message to the client, waits for the response, @@ -82,13 +91,11 @@ pub trait McpServer: Sync + Send { .send(MessageFromServer::RequestFromServer(request), None, timeout) .await?; - let client_messages = response.ok_or_else(|| { + let client_message = response.ok_or_else(|| { RpcError::internal_error() .with_message("An empty response was received from the client.".to_string()) })?; - let client_message = client_messages.as_single()?; - if client_message.is_error() { return Err(client_message.as_error()?.error.into()); } @@ -415,4 +422,7 @@ pub trait McpServer: Sync + Send { } Ok(()) } + + #[cfg(feature = "hyper-server")] + fn session_id(&self) -> Option; } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs new file mode 100644 index 0000000..2372ae9 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/request_id_gen.rs @@ -0,0 +1,101 @@ +use std::sync::atomic::AtomicI64; + +use crate::schema::{schema_utils::McpMessage, RequestId}; +use async_trait::async_trait; + +/// A trait for generating and managing request IDs in a thread-safe manner. +/// +/// Implementors provide functionality to generate unique request IDs, retrieve the last +/// generated ID, and reset the ID counter. +#[async_trait] +pub trait RequestIdGen: Send + Sync { + fn next_request_id(&self) -> RequestId; + #[allow(unused)] + fn last_request_id(&self) -> Option; + #[allow(unused)] + fn reset_to(&self, id: u64); + + /// Determines the request ID for an outgoing MCP message. + /// + /// For requests, generates a new ID using the internal counter. For responses or errors, + /// uses the provided `request_id`. Notifications receive no ID. + /// + /// # Arguments + /// * `message` - The MCP message to evaluate. + /// * `request_id` - An optional existing request ID (required for responses/errors). + /// + /// # Returns + /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. + fn request_id_for_message( + &self, + message: &dyn McpMessage, + request_id: Option, + ) -> Option { + // we need to produce next request_id for requests + if message.is_request() { + // request_id should be None for requests + assert!(request_id.is_none()); + Some(self.next_request_id()) + } else if !message.is_notification() { + // `request_id` must not be `None` for errors, notifications and responses + assert!(request_id.is_some()); + request_id + } else { + None + } + } +} + +pub struct RequestIdGenNumeric { + message_id_counter: AtomicI64, + last_message_id: AtomicI64, +} + +impl RequestIdGenNumeric { + pub fn new(initial_id: Option) -> Self { + Self { + message_id_counter: AtomicI64::new(initial_id.unwrap_or(0) as i64), + last_message_id: AtomicI64::new(-1), + } + } +} + +impl RequestIdGen for RequestIdGenNumeric { + /// Generates the next unique request ID as an integer. + /// + /// Increments the internal counter atomically and updates the last generated ID. + /// Uses `Relaxed` ordering for performance, as the counter only needs to ensure unique IDs. + fn next_request_id(&self) -> RequestId { + let id = self + .message_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // Store the new ID as the last generated ID + self.last_message_id + .store(id, std::sync::atomic::Ordering::Relaxed); + RequestId::Integer(id) + } + + /// Returns the last generated request ID, if any. + /// + /// Returns `None` if no ID has been generated (indicated by a sentinel value of -1). + /// Uses `Relaxed` ordering since the read operation doesn’t require synchronization + /// with other memory operations beyond atomicity. + fn last_request_id(&self) -> Option { + let last_id = self + .last_message_id + .load(std::sync::atomic::Ordering::Relaxed); + if last_id == -1 { + None + } else { + Some(RequestId::Integer(last_id)) + } + } + + /// Resets the internal counter to the specified ID. + /// + /// The provided `id` (u64) is converted to i64 and stored atomically. + fn reset_to(&self, id: u64) { + self.message_id_counter + .store(id as i64, std::sync::atomic::Ordering::Relaxed); + } +} diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index de92a06..16fe7c7 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -1,9 +1,26 @@ use crate::schema::schema_utils::{ClientMessages, SdkError}; -use crate::error::{McpSdkError, SdkResult}; +use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; +/// A guard type that automatically aborts a Tokio task when dropped. +/// +/// This ensures that the associated task does not outlive the scope +/// of this struct, preventing runaway or leaked background tasks. +/// +pub struct AbortTaskOnDrop { + /// The handle used to abort the spawned Tokio task. + pub handle: tokio::task::AbortHandle, +} + +impl Drop for AbortTaskOnDrop { + fn drop(&mut self) { + // Automatically abort the associated task when this guard is dropped. + self.handle.abort(); + } +} + /// Formats an assertion error message for unsupported capabilities. /// /// Constructs a string describing that a specific entity (e.g., server or client) lacks @@ -54,20 +71,20 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st /// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05"); /// assert!(result.is_ok()); /// -/// // Incompatible versions (client < server) +/// // Incompatible versions (requested < current) /// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2024_11_05" && server == "2025_03_26" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2024_11_05" && current == "2025_03_26" /// )); /// -/// // Incompatible versions (client > server) +/// // Incompatible versions (requested > current) /// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -76,10 +93,12 @@ pub fn ensure_server_protocole_compatibility( server_protocol_version: &str, ) -> SdkResult<()> { match client_protocol_version.cmp(server_protocol_version) { - Ordering::Less | Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(()), } } @@ -123,8 +142,8 @@ pub fn ensure_server_protocole_compatibility( /// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -134,10 +153,12 @@ pub fn enforce_compatible_protocol_version( ) -> SdkResult> { match client_protocol_version.cmp(server_protocol_version) { // if client protocol version is higher - Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(None), Ordering::Less => { // return the same version that was received from the client @@ -147,7 +168,10 @@ pub fn enforce_compatible_protocol_version( } pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> { - let _mcp_protocol_version = ProtocolVersion::try_from(mcp_protocol_version)?; + let _mcp_protocol_version = + ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol { + kind: ProtocolErrorKind::ParseError(err), + })?; Ok(()) } diff --git a/crates/rust-mcp-sdk/tests/check_imports.rs b/crates/rust-mcp-sdk/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-sdk/tests/check_imports.rs +++ b/crates/rust-mcp-sdk/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 57a3ea8..d6b45f7 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -1,5 +1,8 @@ +mod mock_server; +mod test_client; mod test_server; use async_trait::async_trait; +pub use mock_server::*; use reqwest::{Client, Response, Url}; use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::ProtocolVersion; @@ -8,16 +11,31 @@ use rust_mcp_sdk::mcp_client::ClientHandler; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::sync::Once; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::time::timeout; use tokio_stream::StreamExt; +use tracing_subscriber::EnvFilter; +use wiremock::{MockServer, Request, ResponseTemplate}; +pub use test_client::*; pub use test_server::*; pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything"; #[cfg(unix)] pub const UVX_SERVER_GIT: &str = "mcp-server-git"; +static INIT: Once = Once::new(); +pub fn init_tracing() { + INIT.call_once(|| { + let filter = EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new("tracing")) + .unwrap(); + + tracing_subscriber::fmt().with_env_filter(filter).init(); + }); +} #[mcp_tool( name = "say_hello", description = "Accepts a person's name and says a personalized \"Hello\" to that person", @@ -120,16 +138,20 @@ pub async fn send_get_request( ); } } + client.get(url).headers(headers).send().await } use futures::stream::Stream; // stream: &mut impl Stream>, +/// reads sse events and return them as (id, event, data) tuple pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), -) -> Option { + event_count: usize, +) -> Option, Option, String)>> { let mut buffer = String::new(); + let mut events = vec![]; while let Some(item) = stream.next().await { match item { @@ -138,42 +160,55 @@ pub async fn read_sse_event_from_stream( buffer.push_str(chunk_str); while let Some(pos) = buffer.find("\n\n") { - let data = { - // Scope to limit borrows - let (event_str, rest) = buffer.split_at(pos); - let mut data = None; - - // Process the event string - for line in event_str.lines() { - if line.starts_with("data:") { - data = Some(line.trim_start_matches("data:").trim().to_string()); - break; // Exit loop after finding data - } + let (event_str, rest) = buffer.split_at(pos); + let mut id = None; + let mut event = None; + let mut data = None; + + // Process the event string + for line in event_str.lines() { + if line.starts_with("id:") { + id = Some(line.trim_start_matches("id:").trim().to_string()); + } else if line.starts_with("event:") { + event = Some(line.trim_start_matches("event:").trim().to_string()); + } else if line.starts_with("data:") { + data = Some(line.trim_start_matches("data:").trim().to_string()); } + } - // Update buffer after processing - buffer = rest[2..].to_string(); // Skip "\n\n" - data - }; + // Update buffer after processing + buffer = rest[2..].to_string(); // Skip "\n\n" - // Return if data was found + // Only include events with data if let Some(data) = data { - return Some(data); + events.push((id, event, data)); + if events.len().eq(&event_count) { + return Some(events); + } } } } Err(_e) => { - // return Err(TransportServerError::HyperError(e)); return None; } } } - None + if !events.is_empty() { + Some(events) + } else { + None + } } -pub async fn read_sse_event(response: Response) -> Option { +// return sse event as (id, event, data) tuple +pub async fn read_sse_event( + response: Response, + event_count: usize, +) -> Option, Option, String)>> { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream).await + let events = read_sse_event_from_stream(&mut stream, event_count).await; + // drop(stream); + events } pub fn test_client_info() -> InitializeRequestParams { @@ -269,9 +304,16 @@ pub fn random_port_old() -> u16 { } pub mod sample_tools { + use std::{sync::Arc, time::Duration}; + + use rust_mcp_schema::{LoggingMessageNotificationParams, TextContent}; #[cfg(feature = "2025_06_18")] use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; - use rust_mcp_sdk::schema::{schema_utils::CallToolError, CallToolResult}; + use rust_mcp_sdk::{ + schema::{schema_utils::CallToolError, CallToolResult}, + McpServer, + }; + use serde_json::json; //****************// // SayHelloTool // @@ -331,4 +373,90 @@ pub mod sample_tools { return Ok(CallToolResult::text_content(goodbye_message, None)); } } + + //****************************// + // StartNotificationStream // + //****************************// + #[mcp_tool( + name = "start-notification-stream", + description = "Accepts a person's name and says a personalized \"Goodbye\" to that person." + )] + #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] + pub struct StartNotificationStream { + /// Interval in milliseconds between notifications + interval: u64, + /// Number of notifications to send (0 for 100) + count: u32, + } + impl StartNotificationStream { + pub async fn call_tool( + &self, + runtime: Arc, + ) -> Result { + for i in 0..self.count { + let _ = runtime + .send_logging_message(LoggingMessageNotificationParams { + data: json!({"id":format!("message {} of {}",i,self.count)}), + level: rust_mcp_sdk::schema::LoggingLevel::Emergency, + logger: None, + }) + .await; + tokio::time::sleep(Duration::from_millis(self.interval)).await; + } + + let message = "so many messages sent".to_string(); + Ok(CallToolResult::text_content(vec![TextContent::from( + message, + )])) + } + } +} + +pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request { + let requests = mock_server.received_requests().await.unwrap(); + requests[index].clone() +} + +pub async fn debug_wiremock(mock_server: &MockServer) { + let requests = mock_server.received_requests().await.unwrap(); + let len = requests.len(); + println!(">>> {len} request(s) received <<<"); + + for (index, request) in requests.iter().enumerate() { + println!("\n--- #{index} of {len} ---"); + println!("Method: {}", request.method); + println!("Path: {}", request.url.path()); + // println!("Headers: {:#?}", request.headers); + println!("---- headers ----"); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + let body_str = String::from_utf8_lossy(&request.body); + println!("Body: {body_str}\n"); + } +} + +pub fn create_sse_response(payload: &str) -> ResponseTemplate { + let sse_body = format!(r#"data: {}{}"#, payload, "\n\n"); + ResponseTemplate::new(200).set_body_raw(sse_body.into_bytes(), "text/event-stream") +} + +pub async fn wait_for_n_requests( + mock_server: &MockServer, + num_requests: usize, + duration: Option, +) { + let duration = duration.unwrap_or(Duration::from_secs(1)); + timeout(duration, async { + loop { + let requests = mock_server.received_requests().await.unwrap(); + if requests.len() >= num_requests { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }) + .await + .unwrap(); } diff --git a/crates/rust-mcp-sdk/tests/common/mock_server.rs b/crates/rust-mcp-sdk/tests/common/mock_server.rs new file mode 100644 index 0000000..f5b533a --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/mock_server.rs @@ -0,0 +1,528 @@ +use axum::{ + body::Body, + extract::Request, + http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, Method, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Response, Sse, + }, + routing::any, + Router, +}; +use core::fmt; +use futures::stream; +use std::collections::VecDeque; +use std::{future::Future, net::SocketAddr, pin::Pin}; +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; +use tokio::net::TcpListener; + +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl ToString for SseEvent { + fn to_string(&self) -> String { + let mut s = String::new(); + + if let Some(id) = &self.id { + s.push_str("id: "); + s.push_str(id); + s.push('\n'); + } + + if let Some(event) = &self.event { + s.push_str("event: "); + s.push_str(event); + s.push('\n'); + } + + if let Some(data) = &self.data { + // Convert bytes to string safely, fallback if invalid UTF-8 + for line in data.lines() { + s.push_str("data: "); + s.push_str(line); + s.push('\n'); + } + } + + s.push('\n'); // End of event + s + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self.data.as_ref(); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +// RequestRecord stores the history of incoming requests +#[derive(Clone, Debug)] +pub struct RequestRecord { + pub method: Method, + pub path: String, + pub headers: HeaderMap, + pub body: String, +} + +#[derive(Clone, Debug)] +pub struct ResponseRecord { + pub status: StatusCode, + pub headers: HeaderMap, + pub body: String, +} + +// pub type BoxedStream = +// Pin> + Send>>; +// pub type BoxedSseResponse = Sse; + +// pub type AsyncResponseFn = +// Box Pin + Send>> + Send + Sync>; + +type AsyncResponseFn = + Box Pin + Send>> + Send + Sync>; + +// Mock defines a single mock response configuration +// #[derive(Clone)] +pub struct Mock { + method: Method, + path: String, + response: String, + response_func: Option, + header_map: HeaderMap, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +// MockBuilder is a factory for creating Mock instances +pub struct MockBuilder { + method: Method, + path: String, + response: String, + header_map: HeaderMap, + response_func: Option, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +impl MockBuilder { + fn new(method: Method, path: String, response: String, header_map: HeaderMap) -> Self { + Self { + method, + path, + response, + response_func: None, + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + fn new_with_func( + method: Method, + path: String, + response_func: AsyncResponseFn, + header_map: HeaderMap, + ) -> Self { + Self { + method, + path, + response: String::new(), + response_func: Some(response_func), + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + pub fn new_breakable_sse( + method: Method, + path: String, + repeating_message: SseEvent, + interval: Duration, + repeat: usize, + ) -> Self { + let message = Arc::new(repeating_message); + let interval = interval; + let max_repeats = repeat; + + let response_fn: AsyncResponseFn = Box::new({ + let message = Arc::clone(&message); + move || { + let message = Arc::clone(&message); + + Box::pin(async move { + // Construct SSE stream with 10 static messages using unfold + let message_stream = stream::unfold(0, move |count| { + let message = Arc::clone(&message); + + async move { + if count >= max_repeats { + return Some(( + Err(std::io::Error::other("Message limit reached")), + count, + )); + } + tokio::time::sleep(interval).await; + + Some(( + Ok(Event::default() + .data(message.data.clone().unwrap_or("".into())) + .id(message.id.clone().unwrap_or(format!("msg-id_{count}"))) + .event(message.event.clone().unwrap_or("message".into()))), + count + 1, + )) + } + }); + + let sse_stream = Sse::new(message_stream) + .keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); + + sse_stream.into_response() + }) + } + }); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + Self::new_with_func(method, path, response_fn, header_map) + } + + pub fn with_matcher(mut self, matcher: F) -> Self + where + F: Fn(&str, &HeaderMap) -> bool + Send + Sync + 'static, + { + self.matcher = Some(Arc::new(matcher)); + self + } + + pub fn add_header(mut self, key: HeaderName, val: HeaderValue) -> Self { + self.header_map.insert(key, val); + self + } + + pub fn without_matcher(mut self) -> Self { + self.matcher = None; + self + } + + pub fn expect(mut self, num_calls: usize) -> Self { + self.remaining_calls = Some(Arc::new(Mutex::new(num_calls))); + self + } + + pub fn unlimited_calls(mut self) -> Self { + self.remaining_calls = None; + self + } + + pub fn with_status(mut self, status: StatusCode) -> Self { + self.status = status; + self + } + + pub fn build(self) -> Mock { + Mock { + method: self.method, + path: self.path, + response: self.response, + header_map: self.header_map, + matcher: self.matcher, + remaining_calls: self.remaining_calls, + status: self.status, + response_func: self.response_func, + } + } + + // add_string with text/plain + pub fn new_text(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + + Self::new(method, path, response.into(), header_map) + } + + /** + MockBuilder::new_response( + Method::GET, + "/mcp".to_string(), + Box::new(|| { + // tokio::time::sleep(Duration::from_secs(1)).await; + let json_response = Json(json!({ + "status": "ok", + "data": [1, 2, 3] + })) + .into_response(); + Box::pin(async move { json_response }) + }), + ) + .build(), + */ + pub fn new_response(method: Method, path: String, response_func: AsyncResponseFn) -> Self { + Self::new_with_func(method, path, response_func, HeaderMap::new()) + } + + // new_json with application/json + pub fn new_json(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + Self::new(method, path, response.into(), header_map) + } + + // new_sse with text/event-stream + pub fn new_sse(method: Method, path: String, response: impl Into) -> Self { + let response = format!(r#"data: {}{}"#, response.into(), '\n'); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + // ensure message ends with a \n\n , if needed + let cr = if response.ends_with("\n\n") { + "" + } else { + "\n\n" + }; + Self::new(method, path, format!("{response}{cr}"), header_map) + } + + // new_raw with application/octet-stream + pub fn new_raw(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Self::new(method, path, response.into(), header_map) + } +} + +// MockServerHandle provides access to the request history after the server starts +pub struct MockServerHandle { + history: Arc>>, +} + +impl MockServerHandle { + pub async fn get_history(&self) -> Vec<(RequestRecord, ResponseRecord)> { + let history = self.history.lock().unwrap(); + history.iter().cloned().collect() + } + + pub async fn print(&self) { + let requests = self.get_history().await; + + let len = requests.len(); + println!("\n>>> {len} request(s) received <<<"); + + for (index, (request, response)) in requests.iter().enumerate() { + println!( + "\n--- Request {} of {len} ------------------------------------", + index + 1 + ); + println!("Method: {}", request.method); + println!("Path: {}", request.path); + // println!("Headers: {:#?}", request.headers); + println!("> headers "); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + println!("\n> Body"); + println!("{}\n", &request.body); + + println!(">>>>> Response <<<<<"); + println!("> status: {}", response.status); + println!("> headers"); + for (key, values) in &response.headers { + println!("{key}: {values:?}"); + } + println!("> Body"); + println!("{}", &response.body); + } + } +} + +// MockServer is the main struct for configuring and starting the mock server +pub struct SimpleMockServer { + mocks: Vec, + history: Arc>>, +} + +impl Default for SimpleMockServer { + fn default() -> Self { + Self::new() + } +} + +impl SimpleMockServer { + pub fn new() -> Self { + Self { + mocks: Vec::new(), + history: Arc::new(Mutex::new(VecDeque::new())), + } + } + + pub async fn start_with_mocks(mocks: Vec) -> (String, MockServerHandle) { + let mut server = SimpleMockServer::new(); + server.add_mocks(mocks); + server.start().await + } + + // Generic add function + pub fn add_mock_builder(&mut self, builder: MockBuilder) -> &mut Self { + self.mocks.push(builder.build()); + self + } + + pub fn add_mock(&mut self, mock: Mock) -> &mut Self { + self.mocks.push(mock); + self + } + + pub fn add_mocks(&mut self, mock: Vec) -> &mut Self { + mock.into_iter().for_each(|m| self.mocks.push(m)); + self + } + + pub async fn start(self) -> (String, MockServerHandle) { + let mocks = Arc::new(self.mocks); + let history = Arc::clone(&self.history); + + async fn handler( + mocks: Arc>, + history: Arc>>, + mut req: Request, + ) -> impl IntoResponse { + // Take ownership of the body using std::mem::take + let body = std::mem::take(req.body_mut()); + let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); + let body_str = String::from_utf8_lossy(&body_bytes).to_string(); + + let request_record = RequestRecord { + method: req.method().clone(), + path: req.uri().path().to_string(), + headers: req.headers().clone(), + body: body_str.clone(), + }; + + for m in mocks.iter() { + if m.method != *req.method() || m.path != req.uri().path() { + continue; + } + + if let Some(matcher) = &m.matcher { + if !(matcher)(&body_str, req.headers()) { + continue; + } + } + + if let Some(remaining) = &m.remaining_calls { + let mut rem = remaining.lock().unwrap(); + if *rem == 0 { + continue; + } + *rem -= 1; + } + + let mut resp = match m.response_func.as_ref() { + Some(get_response) => get_response().await.into_response(), + None => Response::new(Body::from(m.response.clone())), + }; + + // if let Some(resp_box) = &mut m.response_func.take() { + // let response = resp_box.into_response(); + // // *response.status_mut() = m.status; + // // m.response_func = Some(Box::new(response)); + // } + + // let mut resp = m.response_func.as_ref().unwrap().clone().to_owned(); + // let resp = *resp; + // *resp.into_response().status_mut() = m.status; + + // let mut response = m.response_func.as_ref().unwrap().clone(); + // let mut response = m.response_func.as_ref().unwrap().clone().to_owned(); + // let mut m = *response; + // *response.status_mut() = m.status; + // let resp = &*m.response_func.as_ref().unwrap().to_owned().clone().deref(); + + // let response = boxed_response.into_response(); + + // let mut resp = Response::new(Body::from(m.response.clone())); + *resp.status_mut() = m.status; + m.header_map.iter().for_each(|(k, v)| { + resp.headers_mut().insert(k, v.clone()); + }); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: m.response.clone(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + return resp; + } + + let resp = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: "".into(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + resp + } + + let app = Router::new().route( + "/{*path}", + any(move |req: Request| handler(Arc::clone(&mocks), Arc::clone(&history), req)), + ); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let url = format!("http://{local_addr}/"); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + ( + url, + MockServerHandle { + history: self.history, + }, + ) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs new file mode 100644 index 0000000..46a8525 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -0,0 +1,163 @@ +use async_trait::async_trait; +use rust_mcp_schema::{schema_utils::MessageFromServer, PingRequest, RpcError}; +use rust_mcp_sdk::{mcp_client::ClientHandler, McpClient}; +use serde_json::json; +use std::sync::Arc; +use tokio::sync::RwLock; + +#[cfg(feature = "hyper-server")] +pub mod test_client_common { + use rust_mcp_schema::{ + schema_utils::MessageFromServer, ClientCapabilities, Implementation, + InitializeRequestParams, LATEST_PROTOCOL_VERSION, + }; + use rust_mcp_sdk::{ + mcp_client::{client_runtime, ClientRuntime}, + McpClient, RequestOptions, SessionId, StreamableTransportOptions, + }; + use std::{collections::HashMap, sync::Arc, time::Duration}; + use tokio::sync::RwLock; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + use wiremock::{ + matchers::{body_json_string, method, path}, + Mock, MockServer, ResponseTemplate, + }; + + use crate::common::{ + create_sse_response, test_server_common::INITIALIZE_RESPONSE, wait_for_n_requests, + }; + + pub struct InitializedClient { + pub client: Arc, + pub mcp_url: String, + pub mock_server: MockServer, + } + + pub const TEST_SESSION_ID: &str = "test-session-id"; + pub const INITIALIZE_REQUEST: &str = r#"{"id":0,"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{},"clientInfo":{"name":"simple-rust-mcp-client-sse","title":"Simple Rust MCP Client (SSE)","version":"0.1.0"},"protocolVersion":"2025-06-18"}}"#; + + pub fn test_client_details() -> InitializeRequestParams { + InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + } + } + + pub async fn create_client( + mcp_url: &str, + custom_headers: Option>, + ) -> (Arc, Arc>>) { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let client_details: InitializeRequestParams = test_client_details(); + + let transport_options = StreamableTransportOptions { + mcp_url: mcp_url.to_string(), + request_options: RequestOptions { + request_timeout: Duration::from_secs(2), + custom_headers, + ..RequestOptions::default() + }, + }; + + let message_history = Arc::new(RwLock::new(vec![])); + let handler = super::TestClientHandler { + message_history: message_history.clone(), + }; + + let client = + client_runtime::with_transport_options(client_details, transport_options, handler); + + // client.clone().start().await.unwrap(); + (client, message_history) + } + + pub async fn initialize_client( + session_id: Option, + custom_headers: Option>, + ) -> InitializedClient { + let mock_server = MockServer::start().await; + + // initialize response + let mut response = create_sse_response(INITIALIZE_RESPONSE); + + if let Some(session_id) = session_id { + response = response.append_header("mcp-session-id", session_id.as_str()); + } + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, custom_headers).await; + + client.clone().start().await.unwrap(); + + wait_for_n_requests(&mock_server, 2, None).await; + + InitializedClient { + client, + mcp_url, + mock_server, + } + } +} + +// Custom responder for SSE with 10 ping messages +struct SsePingResponder; + +// Test handler +pub struct TestClientHandler { + message_history: Arc>>, +} + +impl TestClientHandler { + async fn register_message(&self, message: &MessageFromServer) { + let mut lock = self.message_history.write().await; + lock.push(message.clone()); + } +} + +#[async_trait] +impl ClientHandler for TestClientHandler { + async fn handle_ping_request( + &self, + request: PingRequest, + runtime: &dyn McpClient, + ) -> std::result::Result { + self.register_message(&request.into()).await; + + Ok(rust_mcp_schema::Result { + meta: Some(json!({"meta_number":1515}).as_object().unwrap().to_owned()), + extra: None, + }) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index aa8e2fb..d64244b 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,36 +1,38 @@ #[cfg(feature = "hyper-server")] pub mod test_server_common { + use crate::common::sample_tools::SayHelloTool; use async_trait::async_trait; use rust_mcp_schema::schema_utils::CallToolError; use rust_mcp_schema::{ CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, RpcError, }; + use rust_mcp_sdk::event_store::EventStore; + use rust_mcp_sdk::id_generator::IdGenerator; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; - use tokio_stream::StreamExt; - use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequest, InitializeRequestParams, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, }; use rust_mcp_sdk::{ - mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, + mcp_server::{hyper_server, HyperServer, HyperServerOptions, ServerHandler}, McpServer, SessionId, }; - use std::sync::RwLock; + use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::time::timeout; - - use crate::common::sample_tools::SayHelloTool; + use tokio_stream::StreamExt; pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + pub const INITIALIZE_RESPONSE: &str = r#"{"result":{"protocolVersion":"2025-06-18","capabilities":{"prompts":{},"resources":{"subscribe":true},"tools":{},"logging":{}},"serverInfo":{"name":"example-servers/everything","version":"1.0.0"}},"jsonrpc":"2.0","id":0}"#; pub struct LaunchedServer { pub hyper_runtime: HyperRuntime, pub streamable_url: String, pub sse_url: String, pub sse_message_url: String, + pub event_store: Option>, } pub fn initialize_request() -> InitializeRequest { @@ -71,16 +73,10 @@ pub mod test_server_common { #[async_trait] impl ServerHandler for TestServerHandler { - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } - async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; @@ -94,7 +90,7 @@ pub mod test_server_common { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -126,6 +122,7 @@ pub mod test_server_common { let sse_url = options.sse_url(); let sse_message_url = options.sse_message_url(); + let event_store_clone = options.event_store.clone(); let server = hyper_server::create_server(test_server_details(), TestServerHandler {}, options); @@ -138,6 +135,7 @@ pub mod test_server_common { streamable_url, sse_url, sse_message_url, + event_store: event_store_clone, } } @@ -156,14 +154,17 @@ pub mod test_server_common { } } - impl IdGenerator for TestIdGenerator { - fn generate(&self) -> SessionId { + impl IdGenerator for TestIdGenerator + where + T: From, + { + fn generate(&self) -> T { let mut lock = self.generated.write().unwrap(); *lock += 1; if *lock > self.constant_ids.len() { *lock = 1; } - self.constant_ids[*lock - 1].to_owned() + T::from(self.constant_ids[*lock - 1].to_owned()) } } diff --git a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs index 5c184cf..9f2fd95 100644 --- a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs +++ b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs @@ -30,7 +30,7 @@ mod protocol_compatibility_on_server { ); handler - .handle_initialize_request(InitializeRequest::new(initialize_request), &runtime) + .handle_initialize_request(InitializeRequest::new(initialize_request), runtime) .await } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs new file mode 100644 index 0000000..ceb778a --- /dev/null +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -0,0 +1,824 @@ +#[path = "common/common.rs"] +pub mod common; + +use common::test_client_common::create_client; +use hyper::{Method, StatusCode}; +use rust_mcp_schema::{ + schema_utils::{ + ClientJsonrpcRequest, ClientMessage, MessageFromServer, RequestFromClient, + RequestFromServer, ResultFromServer, RpcMessage, ServerMessage, + }, + RequestId, ServerRequest, ServerResult, +}; +use rust_mcp_sdk::{ + error::McpSdkError, mcp_server::HyperServerOptions, McpClient, TransportError, + MCP_LAST_EVENT_ID_HEADER, +}; +use serde_json::{json, Value}; +use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; +use wiremock::{ + http::{HeaderName, HeaderValue}, + matchers::{body_json_string, header, method, path}, + Mock, MockServer, ResponseTemplate, +}; + +use crate::common::{ + create_sse_response, debug_wiremock, random_port, + test_client_common::{ + initialize_client, InitializedClient, INITIALIZE_REQUEST, TEST_SESSION_ID, + }, + test_server_common::{ + create_start_server, LaunchedServer, TestIdGenerator, INITIALIZE_RESPONSE, + }, + wait_for_n_requests, wiremock_request, MockBuilder, SimpleMockServer, SseEvent, +}; + +// should send JSON-RPC messages via POST +#[tokio::test] +async fn should_send_json_rpc_messages_via_post() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should send batch messages +#[tokio::test] +async fn should_send_batch_messages() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(result + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + // debug_wiremock(&mock_server).await; +} + +// should store session ID received during initialization +#[tokio::test] +async fn should_store_session_id_received_during_initialization() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = + create_sse_response(INITIALIZE_RESPONSE).append_header("mcp-session-id", "test-session-id"); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should terminate session with DELETE request +#[tokio::test] +async fn should_terminate_session_with_delete_request() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 405 response when server doesn't support session termination +#[tokio::test] +async fn should_handle_405_unsupported_session_termination() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(405)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 404 response when session expires +#[tokio::test] +async fn should_handle_404_response_when_session_expires() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(404)) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + matches!( + result, + Err(McpSdkError::Transport(TransportError::SessionExpired)) + ); +} + +// should handle non-streaming JSON response +#[tokio::test] +async fn should_handle_non_streaming_json_response() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + let response = ResponseTemplate::new(200) + .set_body_json(json!({ + "id":1,"jsonrpc":"2.0", "result":{"something":"good"} + })) + .insert_header("Content-Type", "application/json"); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + let request = RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})); + + let result = client.request(request, None).await.unwrap(); + + let ResultFromServer::ServerResult(ServerResult::Result(result)) = result else { + panic!("Wrong result variant!") + }; + + let extra = result.extra.unwrap(); + assert_eq!(extra.get("something").unwrap(), "good"); +} + +// should handle successful initial GET connection for SSE +#[tokio::test] +async fn should_handle_successful_initial_get_connection_for_sse() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str("data: Connection established\n\n"); + + let response = ResponseTemplate::new(200) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()) +} + +#[tokio::test] +async fn should_receive_server_initiated_messaged() { + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + enable_json_response: Some(false), + ..Default::default() + }; + let LaunchedServer { + hyper_runtime, + streamable_url, + sse_url, + sse_message_url, + event_store, + } = create_start_server(server_options).await; + + let (client, message_history) = create_client(&streamable_url, None).await; + + client.clone().start().await.unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let result = hyper_runtime + .ping(&"AAA-BBB-CCC".to_string(), None) + .await + .unwrap(); + + let lock = message_history.read().await; + let ping_request = lock + .iter() + .find(|m| { + matches!( + m, + MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_) + )) + ) + }) + .unwrap(); + let MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_), + )) = ping_request + else { + panic!("Request is not a match!") + }; + assert!(result.meta.is_some()); + + let v = result.meta.unwrap().get("meta_number").unwrap().clone(); + + assert!(matches!(v, Value::Number(value) if value.as_i64().unwrap()==1515)) //1515 is passed from TestClientHandler +} + +// should attempt initial GET connection and handle 405 gracefully +#[tokio::test] +async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(405)) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str("data: Connection established\n\n"); + + let response = ResponseTemplate::new(405) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()); + + // send a batch message, runtime should work as expected with no issue + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); +} + +// should handle multiple concurrent SSE streams +#[tokio::test] +async fn should_handle_multiple_concurrent_sse_streams() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(|req: &wiremock::Request| { + let body_string = String::from_utf8(req.body.clone()).unwrap(); + if body_string.contains("test3") { + create_sse_response(r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#) + } else { + create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ) + } + }) + .expect(2) + .mount(&mock_server) + .await; + + let message_3 = RequestFromClient::CustomRequest(json!({"method": "test3", "params": {}})); + let request1 = client.send_batch(vec![message_1, message_2], None); + let request2 = client.send(message_3.into(), None, None); + + // Run them concurrently and wait for both + let (res_batch, res_single) = tokio::join!(request1, request2); + + let res_batch = res_batch.unwrap().unwrap(); + // two results for two requests in the batch + assert_eq!(res_batch.len(), 2); + assert!(res_batch.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(res_batch + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + let res_single = res_single.unwrap().unwrap(); + let ServerMessage::Response(res_single) = res_single else { + panic!("invalid respinse type, expected Result!") + }; + + assert!(matches!(res_single.id, RequestId::Integer(id) if id==1)); +} + +// should throw error when invalid content-type is received +#[tokio::test] +async fn should_throw_error_when_invalid_content_type_is_received() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":0,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "text/plain", + )) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + let Err(McpSdkError::Transport(TransportError::UnexpectedContentType(content_type))) = result + else { + panic!("Expected a TransportError::UnexpectedContentType error!"); + }; + + assert_eq!(content_type, "text/plain"); +} + +// should always send specified custom headers +#[tokio::test] +async fn should_always_send_specified_custom_headers() { + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, Some(headers)).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "application/json", + )) + .expect(1) + .mount(&mock_server) + .await; + + let _result = client.ping(None).await; + + let requests = mock_server.received_requests().await.unwrap(); + + assert_eq!(requests.len(), 4); + assert!(requests + .iter() + .all(|r| r.headers.get(&"X-Custom-Header".into()).unwrap().as_str() == "CustomValue")); + + debug_wiremock(&mock_server).await +} + +// should reconnect a GET-initiated notification stream that fails + +#[tokio::test] +async fn should_reconnect_a_get_initiated_notification_stream_that_fails() { + // Start a mock server + let mock_server = MockServer::start().await; + + // initialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // two GET Mock, each expects one call , first time it fails, second retry it succeeds + let response = ResponseTemplate::new(502) + .set_body_raw("".to_string().into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .up_to_n_times(1) + .mount(&mock_server) + .await; + + let response = ResponseTemplate::new(200) + .set_body_raw( + "data: Connection established\n\n".to_string().into_bytes(), + "text/event-stream", + ) + .append_header("Connection", "keep-alive"); + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); +} + +//****************** Resumability ****************** +// should pass lastEventId when reconnecting +#[tokio::test] +async fn should_pass_last_event_id_when_reconnecting() { + let msg = r#"{"jsonrpc":"2.0","method":"notifications/message","params":{"data":{},"level":"debug"}}"#; + + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE).build(), + MockBuilder::new_breakable_sse( + Method::GET, + "/mcp".to_string(), + SseEvent { + data: Some(msg.into()), + event: Some("message".to_string()), + id: None, + }, + Duration::from_millis(100), + 5, + ) + .expect(2) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + // give it time for re-connection + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + + let get_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::GET) + .collect(); + + // there should be more than one GET reueat, indicating reconnection + assert!(get_requests.len() > 1); + + let Some(last_get_request) = get_requests.last() else { + panic!("Unable to find last GET request!"); + }; + + let last_event_id = last_get_request + .0 + .headers + .get(axum::http::HeaderName::from_static( + MCP_LAST_EVENT_ID_HEADER, + )); + + // last-event-id should be sent + assert!( + matches!(last_event_id, Some(last_event_id) if last_event_id.to_str().unwrap().starts_with("msg-id")) + ); + + // custom headers should be passed for all GET requests + assert!(get_requests.iter().all(|r| r + .0 + .headers + .get(axum::http::HeaderName::from_str("X-Custom-Header").unwrap()) + .unwrap() + .to_str() + .unwrap() + == "CustomValue")); + + println!("last_event_id {:?} ", last_event_id.unwrap()); +} + +// should NOT reconnect a POST-initiated stream that fails +#[tokio::test] +async fn should_not_reconnect_a_post_initiated_stream_that_fails() { + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE) + .expect(1) + .build(), + MockBuilder::new_sse(Method::GET, "/mcp".to_string(), "".to_string()) + .with_status(StatusCode::METHOD_NOT_ALLOWED) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .expect(1) + .build(), + MockBuilder::new_breakable_sse( + Method::POST, + "/mcp".to_string(), + SseEvent { + data: Some("msg".to_string()), + event: None, + id: None, + }, + Duration::ZERO, + 0, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + let result = client.send_roots_list_changed(None).await; + + assert!(result.is_err()); + + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + let post_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::POST) + .collect(); + assert_eq!(post_requests.len(), 3); // initialize, initialized, root_list_changed +} + +//****************** Auth ****************** +// attempts auth flow on 401 during POST request +// invalidates all credentials on InvalidClientError during auth +// invalidates all credentials on UnauthorizedClientError during auth +//invalidates tokens on InvalidGrantError during auth + +//****************** Others ****************** +// custom fetch in auth code paths +// should support custom reconnection options +// uses custom fetch implementation if provided +// should have exponential backoff with configurable maxRetries diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs similarity index 79% rename from crates/rust-mcp-sdk/tests/test_streamable_http.rs rename to crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 08c85e8..79c9f00 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -3,17 +3,17 @@ use std::{collections::HashMap, error::Error, sync::Arc, time::Duration, vec}; use hyper::StatusCode; use rust_mcp_schema::{ schema_utils::{ - ClientJsonrpcRequest, ClientMessage, ClientMessages, FromMessage, NotificationFromServer, - ResultFromServer, RpcMessage, SdkError, SdkErrorCodes, ServerJsonrpcNotification, - ServerJsonrpcResponse, ServerMessages, + ClientJsonrpcRequest, ClientJsonrpcResponse, ClientMessage, ClientMessages, FromMessage, + NotificationFromServer, RequestFromServer, ResultFromServer, RpcMessage, SdkError, + SdkErrorCodes, ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse, + ServerMessages, }, - CallToolRequest, CallToolRequestParams, ListToolsRequest, LoggingLevel, + CallToolRequest, CallToolRequestParams, ListRootsResult, ListToolsRequest, LoggingLevel, LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, - ServerResult, + ServerRequest, ServerResult, }; -use rust_mcp_sdk::mcp_server::HyperServerOptions; +use rust_mcp_sdk::{event_store::InMemoryEventStore, mcp_server::HyperServerOptions}; use serde_json::{json, Map, Value}; -use tokio_stream::StreamExt; use crate::common::{ random_port, read_sse_event, read_sse_event_from_stream, send_delete_request, send_get_request, @@ -40,6 +40,8 @@ async fn initialize_server( "AAA-BBB-CCC".to_string() ]))), enable_json_response, + ping_interval: Duration::from_secs(1), + event_store: Some(Arc::new(InMemoryEventStore::default())), ..Default::default() }; @@ -168,8 +170,8 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -219,8 +221,8 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -290,12 +292,20 @@ async fn should_reject_invalid_session_id() { server.hyper_runtime.await_server().await.unwrap() } -async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwest::Response { +async fn get_standalone_stream( + streamable_url: &str, + session_id: &str, + last_event_id: Option<&str>, +) -> reqwest::Response { let mut headers = HashMap::new(); headers.insert("Accept", "text/event-stream , application/json"); headers.insert("mcp-session-id", session_id); headers.insert("mcp-protocol-version", "2025-03-26"); + if let Some(last_event_id) = last_event_id { + headers.insert("last-event-id", last_event_id); + } + let response = send_get_request(streamable_url, Some(headers)) .await .unwrap(); @@ -306,7 +316,7 @@ async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwes #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_messages() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -344,8 +354,8 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .await .unwrap(); - let event = read_sse_event(response).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -364,11 +374,96 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { server.hyper_runtime.await_server().await.unwrap() } +// should establish standalone SSE stream and receive server-initiated requests +#[tokio::test] +async fn should_establish_standalone_stream_and_receive_server_requests() { + let (server, session_id) = initialize_server(None).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!( + response + .headers() + .get("mcp-session-id") + .unwrap() + .to_str() + .unwrap(), + session_id + ); + + assert_eq!( + response + .headers() + .get("content-type") + .unwrap() + .to_str() + .unwrap(), + "text/event-stream" + ); + + let hyper_server = Arc::new(server.hyper_runtime); + + // Send two server-initiated request that should appear on SSE stream with a valid request_id + for _ in 0..2 { + let hyper_server_clone = hyper_server.clone(); + let session_id_clone = session_id.to_string(); + tokio::spawn(async move { + hyper_server_clone + .list_roots(&session_id_clone, None) + .await + .unwrap(); + }); + } + + for i in 0..2 { + // send responses back to the server for two server initiated requests + let json_rpc_message: ClientJsonrpcResponse = ClientJsonrpcResponse::new( + RequestId::Integer(i), + ListRootsResult { + meta: None, + roots: vec![], + } + .into(), + ); + send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + } + + // read two events from the sse stream + let events = read_sse_event(response, 2).await.unwrap(); + + let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0].2).unwrap(); + + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request + else { + panic!("invalid message received!"); + }; + + let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1].2).unwrap(); + + let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request + else { + panic!("invalid message received!"); + }; + + // ensure request_ids are unique + assert!(message2.id != message1.id); + + hyper_server.graceful_shutdown(ONE_MILLISECOND); +} + // should not close GET SSE stream after sending multiple server notifications #[tokio::test] async fn should_not_close_get_sse_stream() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -386,8 +481,8 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let mut stream = response.bytes_stream(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -415,8 +510,8 @@ async fn should_not_close_get_sse_stream() { .await .unwrap(); - let event = read_sse_event_from_stream(&mut stream).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification_2, @@ -439,10 +534,10 @@ async fn should_not_close_get_sse_stream() { #[tokio::test] async fn should_reject_second_sse_stream_for_the_same_session() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); - let second_response = get_standalone_stream(&server.streamable_url, &session_id).await; + let second_response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(second_response.status(), StatusCode::CONFLICT); let error_data: SdkError = second_response.json().await.unwrap(); @@ -627,8 +722,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_1.status(), StatusCode::OK); assert_eq!(response_2.status(), StatusCode::OK); - let event = read_sse_event(response_2).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_2, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -643,8 +738,8 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() "Hello, Ali!" ); - let event = read_sse_event(response_1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response_1, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -994,8 +1089,8 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { "text/event-stream" ); - let event = read_sse_event(response).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&event).unwrap(); + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerMessages = serde_json::from_str(&events[0].2).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); @@ -1273,5 +1368,177 @@ async fn should_skip_all_validations_when_false() { server.hyper_runtime.await_server().await.unwrap() } -//TODO: +// should store and include event IDs in server SSE messages +#[tokio::test] +async fn should_store_and_include_event_ids_in_server_sse_messages() { + common::init_tracing(); + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification2"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + // read two events + let events = read_sse_event(response, 2).await.unwrap(); + assert_eq!(events.len(), 2); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + let (second_id, _, _) = events[0].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + assert!(second_id.is_some()); + + //messages should be stored and accessible + let events = server + .event_store + .unwrap() + .events_after(first_id) + .await + .unwrap(); + assert_eq!(events.messages.len(), 1); + + // deserialize the message returned by event_store + let message: ServerJsonrpcNotification = serde_json::from_str(&events.messages[0]).unwrap(); + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification2, + )) = message.notification + else { + panic!("invalid message in store!"); + }; + assert_eq!(notification2.params.data.as_str().unwrap(), "notification2"); +} + +// should store and replay MCP server tool notifications +#[tokio::test] +async fn should_store_and_replay_mcp_server_tool_notifications() { + common::init_tracing(); + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let events = read_sse_event(response, 1).await.unwrap(); + assert_eq!(events.len(), 1); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + + // sse connection is closed in read_sse_event() + // wait so server detect the disconnect and simulate a network error + tokio::time::sleep(Duration::from_secs(3)).await; + tokio::task::yield_now().await; + // we send another notification while SSE is disconnected + let _result = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification2"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + // make a new standalone SSE connection to simulate a re-connection + let response = + get_standalone_stream(&server.streamable_url, &session_id, Some(&first_id)).await; + assert_eq!(response.status(), StatusCode::OK); + let events = read_sse_event(response, 1).await.unwrap(); + + assert_eq!(events.len(), 1); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification2"); +} + // should return 400 error for invalid JSON-RPC messages +// should keep stream open after sending server notifications +// NA: should reject second initialization request +// NA: should pass request info to tool callback +// NA: should reject second SSE stream even in stateless mode +// should reject requests to uninitialized server +// should accept requests with matching protocol version +// should accept when protocol version differs from negotiated version +// should call a tool with authInfo +// should calls tool without authInfo when it is optional +// should accept pre-parsed request body +// should handle pre-parsed batch messages +// should prefer pre-parsed body over request body +// should operate without session ID validation +// should handle POST requests with various session IDs in stateless mode +// should call onsessionclosed callback when session is closed via DELETE +// should not call onsessionclosed callback when not provided +// should not call onsessionclosed callback for invalid session DELETE +// should call onsessionclosed callback with correct session ID when multiple sessions exist +// should support async onsessioninitialized callback +// should support sync onsessioninitialized callback (backwards compatibility) +// should support async onsessionclosed callback +// should propagate errors from async onsessioninitialized callback +// should propagate errors from async onsessionclosed callback +// should handle both async callbacks together +// should validate both host and origin when both are configured diff --git a/crates/rust-mcp-transport/CHANGELOG.md b/crates/rust-mcp-transport/CHANGELOG.md index 1ffd363..2d692b4 100644 --- a/crates/rust-mcp-transport/CHANGELOG.md +++ b/crates/rust-mcp-transport/CHANGELOG.md @@ -1,5 +1,42 @@ # Changelog +## [0.6.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.5.0...rust-mcp-transport-v0.6.0) (2025-09-19) + + +### ⚠ BREAKING CHANGES + +* add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) +* update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) + +### πŸš€ Features + +* Add Streamable HTTP Client , multiple refactoring and improvements ([#98](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/98)) ([abb0c36](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/abb0c36126b0a397bc20a1de36c5a5a80924a01e)) +* Event store support for resumability ([#101](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/101)) ([08742bb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/08742bb9636f81ee79eda4edc192b3b8ed4c7287)) +* Update ServerHandler and ServerHandlerCore traits ([#96](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/96)) ([a2d6d23](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/a2d6d23ab59fbc34d04526e2606f747f93a8468c)) + + +### πŸ› Bug Fixes + +* Correct pending_requests instance ([#94](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/94)) ([9d8c1fb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/9d8c1fbdf3ddb7c67ce1fb7dcb8e50b8ba2e1202)) + +## [0.5.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.5.0...rust-mcp-transport-v0.5.1) (2025-08-31) + + +### πŸ› Bug Fixes + +* Correct pending_requests instance ([#94](https://github.com/rust-mcp-stack/rust-mcp-sdk/issues/94)) ([9d8c1fb](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/9d8c1fbdf3ddb7c67ce1fb7dcb8e50b8ba2e1202)) + +## [0.5.0](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.4.1...rust-mcp-transport-v0.5.0) (2025-08-19) + + +### ⚠ BREAKING CHANGES + +* improve request ID generation, remove deprecated methods and adding improvements + +### πŸš€ Features + +* Improve request ID generation, remove deprecated methods and adding improvements ([95b91aa](https://github.com/rust-mcp-stack/rust-mcp-sdk/commit/95b91aad191e1b8777ca4a02612ab9183e0276d3)) + ## [0.4.1](https://github.com/rust-mcp-stack/rust-mcp-sdk/compare/rust-mcp-transport-v0.4.0...rust-mcp-transport-v0.4.1) (2025-08-12) diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index 94fd5ba..8331eaf 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-mcp-transport" -version = "0.4.1" +version = "0.6.0" authors = ["Ali Hashemi"] categories = ["data-structures"] description = "Transport implementations for the MCP (Model Context Protocol) within the rust-mcp-sdk ecosystem, enabling asynchronous data exchange and efficient message handling between MCP clients and servers." @@ -42,10 +42,12 @@ workspace = true ### FEATURES ################################################################# [features] -default = ["stdio", "sse", "2025_06_18"] # Default features +default = ["stdio", "sse", "streamable-http", "2025_06_18"] # Default features stdio = [] sse = ["reqwest"] +streamable-http = ["reqwest"] + # enabled mcp protocol version 2025_06_18 2025_06_18 = ["rust-mcp-schema/2025_06_18", "rust-mcp-schema/schema_utils"] diff --git a/crates/rust-mcp-transport/README.md b/crates/rust-mcp-transport/README.md index 23b78bf..30cad83 100644 --- a/crates/rust-mcp-transport/README.md +++ b/crates/rust-mcp-transport/README.md @@ -14,7 +14,7 @@ let transport = StdioTransport::new(TransportOptions { timeout: 60_000 })?; ``` -Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) example for a complete demonstration. +Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) example for a complete demonstration. ### For MCP Client @@ -51,7 +51,7 @@ let transport = StdioTransport::create_with_server_launch( )?; ``` -Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) example for a complete demonstration. +Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) example for a complete demonstration. --- diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index f201aa0..0a1e8f3 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -5,7 +5,7 @@ use crate::transport::Transport; use crate::utils::{ extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel, }; -use crate::{IoStream, McpDispatch, TransportOptions}; +use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions}; use async_trait::async_trait; use bytes::Bytes; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; @@ -13,8 +13,13 @@ use reqwest::Client; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; -use crate::schema::schema_utils::McpMessage; -use crate::schema::RequestId; +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; use std::cmp::Ordering; use std::collections::HashMap; use std::pin::Pin; @@ -25,7 +30,7 @@ use tokio::sync::{mpsc, oneshot, Mutex}; const DEFAULT_CHANNEL_CAPACITY: usize = 64; const DEFAULT_MAX_RETRY: usize = 5; -const DEFAULT_RETRY_TIME_SECONDS: u64 = 3; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; /// Configuration options for the Client SSE Transport @@ -102,10 +107,9 @@ where let base_url = match extract_origin(server_url) { Some(url) => url, None => { - let error_message = - format!("Failed to extract origin from server URL: {server_url}"); - tracing::error!(error_message); - return Err(TransportError::InvalidOptions(error_message)); + let message = format!("Failed to extract origin from server URL: {server_url}"); + tracing::error!(message); + return Err(TransportError::Configuration { message }); } }; @@ -145,12 +149,15 @@ where let mut header_map = HeaderMap::new(); for (key, value) in headers { - let header_name = key - .parse::() - .map_err(|e| TransportError::InvalidOptions(format!("Invalid header name: {e}")))?; - let header_value = HeaderValue::from_str(value).map_err(|e| { - TransportError::InvalidOptions(format!("Invalid header value: {e}")) - })?; + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; header_map.insert(header_name, header_value); } @@ -172,10 +179,12 @@ where } if let Some(endpoint_origin) = extract_origin(&endpoint) { if endpoint_origin.cmp(&self.base_url) != Ordering::Equal { - return Err(TransportError::InvalidOptions(format!( + return Err(TransportError::Configuration { + message: format!( "Endpoint origin does not match connection origin. expected: {} , received: {}", self.base_url, endpoint_origin - ))); + ), + }); } return Ok(endpoint); } @@ -284,8 +293,8 @@ where Some(data) => { // trim the trailing \n before making a request let body = String::from_utf8_lossy(&data).trim().to_string(); - if let Err(e) = http_post(&client_clone, &post_url, body, &custom_headers).await { - tracing::error!("Failed to POST message: {e:?}"); + if let Err(e) = http_post(&client_clone, &post_url, body,None, custom_headers.as_ref()).await { + tracing::error!("Failed to POST message: {e}"); } }, None => break, // Exit if channel is closed @@ -335,7 +344,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function for ClientSseTransport" .to_string(), )) @@ -346,7 +355,7 @@ where _: Duration, _: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for ClientSseTransport".to_string(), )) } @@ -413,3 +422,55 @@ where pending_requests.remove(request_id) } } + +#[async_trait] +impl McpDispatch + for ClientSseTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload, skip_store).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientSseTransport +{ +} diff --git a/crates/rust-mcp-transport/src/client_streamable_http.rs b/crates/rust-mcp-transport/src/client_streamable_http.rs new file mode 100644 index 0000000..edda062 --- /dev/null +++ b/crates/rust-mcp-transport/src/client_streamable_http.rs @@ -0,0 +1,515 @@ +use crate::error::TransportError; +use crate::mcp_stream::MCPStream; + +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; +use crate::utils::{ + http_delete, http_post, CancellationTokenSource, ReadableChannel, StreamableHttpStream, + WritableChannel, +}; +use crate::{error::TransportResult, IoStream, McpDispatch, MessageDispatcher, Transport}; +use crate::{SessionId, TransportDispatcher, TransportOptions}; +use async_trait::async_trait; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use reqwest::Client; +use std::collections::HashMap; +use std::pin::Pin; +use std::{sync::Arc, time::Duration}; +use tokio::io::{BufReader, BufWriter}; +use tokio::sync::oneshot::Sender; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; + +const DEFAULT_CHANNEL_CAPACITY: usize = 64; +const DEFAULT_MAX_RETRY: usize = 5; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; +const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; + +pub struct StreamableTransportOptions { + pub mcp_url: String, + pub request_options: RequestOptions, +} + +impl StreamableTransportOptions { + pub async fn terminate_session(&self, session_id: Option<&SessionId>) { + let client = Client::new(); + match http_delete(&client, &self.mcp_url, session_id, None).await { + Ok(_) => {} + Err(TransportError::Http(status_code)) => { + tracing::info!("Session termination failed with status code {status_code}",); + } + Err(error) => { + tracing::info!("Session termination failed with error :{error}"); + } + }; + } +} + +pub struct RequestOptions { + pub request_timeout: Duration, + pub retry_delay: Option, + pub max_retries: Option, + pub custom_headers: Option>, +} + +impl Default for RequestOptions { + fn default() -> Self { + Self { + request_timeout: TransportOptions::default().timeout, + retry_delay: None, + max_retries: None, + custom_headers: None, + } + } +} + +pub struct ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + /// Optional cancellation token source for shutting down the transport + shutdown_source: tokio::sync::RwLock>, + /// Flag indicating if the transport is shut down + is_shut_down: Mutex, + /// Timeout duration for MCP messages + request_timeout: Duration, + /// HTTP client for making requests + client: Client, + /// URL for the SSE endpoint + mcp_server_url: String, + /// Delay between retry attempts + retry_delay: Duration, + /// Maximum number of retry attempts + max_retries: usize, + /// Optional custom HTTP headers + custom_headers: Option, + sse_task: tokio::sync::RwLock>>, + post_task: tokio::sync::RwLock>>, + message_sender: Arc>>>, + error_stream: tokio::sync::RwLock>, + pending_requests: Arc>>>, + session_id: Arc>>, + standalone: bool, +} + +impl ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + pub fn new( + options: &StreamableTransportOptions, + session_id: Option, + standalone: bool, + ) -> TransportResult { + let client = Client::new(); + + let headers = match &options.request_options.custom_headers { + Some(h) => Some(Self::validate_headers(h)?), + None => None, + }; + + let mcp_server_url = options.mcp_url.to_owned(); + Ok(Self { + shutdown_source: tokio::sync::RwLock::new(None), + is_shut_down: Mutex::new(false), + request_timeout: options.request_options.request_timeout, + client, + mcp_server_url, + retry_delay: options + .request_options + .retry_delay + .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)), + max_retries: options + .request_options + .max_retries + .unwrap_or(DEFAULT_MAX_RETRY), + sse_task: tokio::sync::RwLock::new(None), + post_task: tokio::sync::RwLock::new(None), + custom_headers: headers, + message_sender: Arc::new(tokio::sync::RwLock::new(None)), + error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: Arc::new(tokio::sync::RwLock::new(session_id)), + standalone, + }) + } + + fn validate_headers(headers: &HashMap) -> TransportResult { + let mut header_map = HeaderMap::new(); + for (key, value) in headers { + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; + header_map.insert(header_name, header_value); + } + Ok(header_map) + } + + pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { + let mut lock = self.message_sender.write().await; + *lock = Some(sender); + } + + pub(crate) async fn set_error_stream( + &self, + error_stream: Pin>, + ) { + let mut lock = self.error_stream.write().await; + *lock = Some(IoStream::Readable(error_stream)); + } +} + +#[async_trait] +impl Transport for ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + async fn start(&self) -> TransportResult> + where + MessageDispatcher: McpDispatch, + { + if self.standalone { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + let session_id = self.session_id.read().await.to_owned(); + + let sse_response = streamable_http + .make_standalone_stream_connection(&cancellation_token_sse, &custom_headers, None) + .await?; + + let sse_task_handle = tokio::spawn(async move { + if let Err(error) = streamable_http + .run_standalone(&cancellation_token_sse, &custom_headers, sse_response) + .await + { + if !matches!(error, TransportError::Cancelled(_)) { + tracing::warn!("{error}"); + } + } + }); + + let mut sse_task_lock = self.sse_task.write().await; + *sse_task_lock = Some(sse_task_handle); + + let post_url = self.mcp_server_url.clone(); + let client = self.client.clone(); + let custom_headers = self.custom_headers.clone(); + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some(data) => { + // trim the trailing \n before making a request + let payload = String::from_utf8_lossy(&data).trim().to_string(); + + if let Err(e) = http_post( + &client, + &post_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await{ + tracing::error!("Failed to POST message: {e}") + } + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create writable stream + let writable: Mutex>> = + Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx }))); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create( + readable, + writable, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + Ok(stream) + } else { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + // let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (write_tx, mut write_rx): ( + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + tokio::sync::mpsc::Receiver<( + String, + tokio::sync::oneshot::Sender>, + )>, + ) = tokio::sync::mpsc::channel(DEFAULT_CHANNEL_CAPACITY); // Buffer size as needed + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some((data, ack_tx)) => { + // trim the trailing \n before making a request + let payload = data.trim().to_string(); + let result = streamable_http.run(payload, &cancellation_token_sse, &custom_headers).await; + let _ = ack_tx.send(result);// Ignore error if receiver dropped + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create_with_ack( + readable, + write_tx, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + + Ok(stream) + } + } + + fn message_sender(&self) -> Arc>>> { + self.message_sender.clone() as _ + } + + fn error_stream(&self) -> &tokio::sync::RwLock> { + &self.error_stream as _ + } + async fn shut_down(&self) -> TransportResult<()> { + // Trigger cancellation + let mut cancellation_lock = self.shutdown_source.write().await; + if let Some(source) = cancellation_lock.as_ref() { + source.cancel()?; + } + *cancellation_lock = None; // Clear cancellation_source + + // Mark as shut down + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + // Get task handle + let post_task = self.post_task.write().await.take(); + + // // Wait for tasks to complete with a timeout + let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS); + let shutdown_future = async { + if let Some(post_handle) = post_task { + let _ = post_handle.await; + } + Ok::<(), TransportError>(()) + }; + + tokio::select! { + result = shutdown_future => { + result // result of task completion + } + _ = tokio::time::sleep(timeout) => { + tracing::warn!("Shutdown timed out after {:?}", timeout); + Err(TransportError::ShutdownTimeout) + } + } + } + async fn is_shut_down(&self) -> bool { + let result = self.is_shut_down.lock().await; + *result + } + async fn consume_string_payload(&self, _: &str) -> TransportResult<()> { + Err(TransportError::Internal( + "Invalid invocation of consume_string_payload() function for ClientStreamableTransport" + .to_string(), + )) + } + + async fn pending_request_tx(&self, request_id: &RequestId) -> Option> { + let mut pending_requests = self.pending_requests.lock().await; + pending_requests.remove(request_id) + } + + async fn keep_alive( + &self, + _: Duration, + _: oneshot::Sender<()>, + ) -> TransportResult> { + Err(TransportError::Internal( + "Invalid invocation of keep_alive() function for ClientStreamableTransport".to_string(), + )) + } + + async fn session_id(&self) -> Option { + let guard = self.session_id.read().await; + guard.clone() + } +} + +#[async_trait] +impl McpDispatch + for ClientStreamableTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload, skip_store).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientStreamableTransport +{ +} diff --git a/crates/rust-mcp-transport/src/constants.rs b/crates/rust-mcp-transport/src/constants.rs new file mode 100644 index 0000000..6ae0342 --- /dev/null +++ b/crates/rust-mcp-transport/src/constants.rs @@ -0,0 +1,3 @@ +pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; +pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; +pub const MCP_LAST_EVENT_ID_HEADER: &str = "last-event-id"; diff --git a/crates/rust-mcp-transport/src/error.rs b/crates/rust-mcp-transport/src/error.rs index 8f8b62f..a244456 100644 --- a/crates/rust-mcp-transport/src/error.rs +++ b/crates/rust-mcp-transport/src/error.rs @@ -1,11 +1,14 @@ use crate::schema::{schema_utils::SdkError, RpcError}; -use thiserror::Error; - use crate::utils::CancellationError; use core::fmt; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::Error as ReqwestError; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::StatusCode; use std::any::Any; +use std::io::Error as IoError; +use thiserror::Error; use tokio::sync::{broadcast, mpsc}; - /// A wrapper around a broadcast send error. This structure allows for generic error handling /// by boxing the underlying error into a type-erased form. #[derive(Debug)] @@ -80,31 +83,53 @@ pub type TransportResult = core::result::Result; #[derive(Debug, Error)] pub enum TransportError { - #[error("{0}")] - InvalidOptions(String), + #[error("Session expired or not found")] + SessionExpired, + + #[error("Failed to open SSE stream: {0}")] + FailedToOpenSSEStream(String), + + #[error("Unexpected content type: '{0}'")] + UnexpectedContentType(String), + + #[error("Failed to send message: {0}")] + SendFailure(String), + + #[error("I/O error: {0}")] + Io(#[from] IoError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP connection error: {0}")] + HttpConnection(#[from] ReqwestError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP error: {0}")] + Http(StatusCode), + + #[error("SDK error: {0}")] + Sdk(#[from] SdkError), + + #[error("Operation cancelled: {0}")] + Cancelled(#[from] CancellationError), + + #[error("Channel closed: {0}")] + ChannelClosed(#[from] tokio::sync::oneshot::error::RecvError), + + #[error("Configuration error: {message}")] + Configuration { message: String }, + #[error("{0}")] SendError(#[from] GenericSendError), - #[error("{0}")] - WatchSendError(#[from] GenericWatchSendError), - #[error("Send Error: {0}")] - StdioError(#[from] std::io::Error), + #[error("{0}")] JsonrpcError(#[from] RpcError), - #[error("{0}")] - SdkError(#[from] SdkError), - #[error("Process error{0}")] + + #[error("Process error: {0}")] ProcessError(String), - #[error("{0}")] - FromString(String), - #[error("{0}")] - OneshotRecvError(#[from] tokio::sync::oneshot::error::RecvError), - #[cfg(feature = "sse")] - #[error("{0}")] - SendMessageError(#[from] reqwest::Error), - #[error("Http Error: {0}")] - HttpError(u16), + + #[error("Internal error: {0}")] + Internal(String), + #[error("Shutdown timed out")] ShutdownTimeout, - #[error("Cancellation error : {0}")] - CancellationError(#[from] CancellationError), } diff --git a/crates/rust-mcp-transport/src/event_store.rs b/crates/rust-mcp-transport/src/event_store.rs new file mode 100644 index 0000000..fdc0734 --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store.rs @@ -0,0 +1,27 @@ +mod in_memory_event_store; +use async_trait::async_trait; +pub use in_memory_event_store::*; + +use crate::{EventId, SessionId, StreamId}; + +#[derive(Debug, Clone)] +pub struct EventStoreMessages { + pub session_id: SessionId, + pub stream_id: StreamId, + pub messages: Vec, +} + +#[async_trait] +pub trait EventStore: Send + Sync { + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: String, + ) -> EventId; + async fn remove_by_session_id(&self, session_id: SessionId); + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId); + async fn clear(&self); + async fn events_after(&self, last_event_id: EventId) -> Option; +} diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs new file mode 100644 index 0000000..66e738c --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -0,0 +1,274 @@ +use async_trait::async_trait; +use std::collections::HashMap; +use std::collections::VecDeque; +use tokio::sync::RwLock; + +use crate::{ + event_store::{EventStore, EventStoreMessages}, + EventId, SessionId, StreamId, +}; + +const MAX_EVENTS_PER_SESSION: usize = 64; +const ID_SEPARATOR: &str = "-.-"; + +#[derive(Debug, Clone)] +struct EventEntry { + pub stream_id: StreamId, + pub time_stamp: u128, + pub message: String, +} + +#[derive(Debug)] +pub struct InMemoryEventStore { + max_events_per_session: usize, + storage_map: RwLock>>, +} + +impl Default for InMemoryEventStore { + fn default() -> Self { + Self { + max_events_per_session: MAX_EVENTS_PER_SESSION, + storage_map: Default::default(), + } + } +} + +/// In-memory implementation of the `EventStore` trait for MCP's Streamable HTTP transport. +/// +/// Stores events in a `HashMap` of session IDs to `VecDeque`s of events, with a per-session limit. +/// Events are identified by `event_id` (format: `session-.-stream-.-timestamp`) and used for SSE resumption. +/// Thread-safe via `RwLock` for concurrent access. +impl InMemoryEventStore { + /// Creates a new `InMemoryEventStore` with an optional maximum events per session. + /// + /// # Arguments + /// - `max_events_per_session`: Maximum number of events per session. Defaults to `MAX_EVENTS_PER_SESSION` (32) if `None`. + /// + /// # Returns + /// A new `InMemoryEventStore` instance with an empty `HashMap` wrapped in a `RwLock`. + /// + /// # Example + /// ``` + /// let store = InMemoryEventStore::new(Some(10)); + /// assert_eq!(store.max_events_per_session, 10); + /// ``` + pub fn new(max_events_per_session: Option) -> Self { + Self { + max_events_per_session: max_events_per_session.unwrap_or(MAX_EVENTS_PER_SESSION), + storage_map: RwLock::new(HashMap::new()), + } + } + + /// Generates an `event_id` string from session, stream, and timestamp components. + /// + /// Format: `session-.-stream-.-timestamp`, used as a resumption cursor in SSE (`Last-Event-ID`). + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// + /// # Returns + /// A `String` in the format `session-.-stream-.-timestamp`. + fn generate_event_id( + &self, + session_id: &SessionId, + stream_id: &StreamId, + time_stamp: u128, + ) -> String { + format!("{session_id}{ID_SEPARATOR}{stream_id}{ID_SEPARATOR}{time_stamp}") + } + + /// Parses an event ID into its session, stream, and timestamp components. + /// + /// The event ID must follow the format `session-.-stream-.-timestamp`. + /// Returns `None` if the format is invalid, empty, or contains invalid characters (e.g., NULL). + /// + /// # Arguments + /// - `event_id`: The event ID string to parse. + /// + /// # Returns + /// An `Option` containing a tuple of `(session_id, stream_id, time_stamp)` as string slices, + /// or `None` if the format is invalid. + /// + /// # Example + /// ``` + /// let store = InMemoryEventStore::new(None); + /// let event_id = "session1-.-stream1-.-12345"; + /// assert_eq!( + /// store.parse_event_id(event_id), + /// Some(("session1", "stream1", "12345")) + /// ); + /// assert_eq!(store.parse_event_id("invalid"), None); + /// ``` + pub fn parse_event_id<'a>(&self, event_id: &'a str) -> Option<(&'a str, &'a str, &'a str)> { + // Check for empty input or invalid characters (e.g., NULL) + if event_id.is_empty() || event_id.contains('\0') { + return None; + } + + // Split into exactly three parts + let parts: Vec<&'a str> = event_id.split(ID_SEPARATOR).collect(); + if parts.len() != 3 { + return None; + } + + let session_id = parts[0]; + let stream_id = parts[1]; + let time_stamp = parts[2]; + + // Ensure no part is empty + if session_id.is_empty() || stream_id.is_empty() || time_stamp.is_empty() { + return None; + } + + Some((session_id, stream_id, time_stamp)) + } +} + +#[async_trait] +impl EventStore for InMemoryEventStore { + /// Stores an event for a given session and stream, returning its `event_id`. + /// + /// Adds the event to the session’s `VecDeque`, removing the oldest event if the session + /// reaches `max_events_per_session`. + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// - `message`: The `ServerMessages` payload. + /// + /// # Returns + /// The generated `EventId` for the stored event. + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: String, + ) -> EventId { + let event_id = self.generate_event_id(&session_id, &stream_id, time_stamp); + + let mut storage_map = self.storage_map.write().await; + + tracing::trace!( + "Storing event for session: {session_id}, stream_id: {stream_id}, message: '{message}', {time_stamp} ", + ); + + let session_map = storage_map + .entry(session_id) + .or_insert_with(|| VecDeque::with_capacity(self.max_events_per_session)); + + if session_map.len() == self.max_events_per_session { + session_map.pop_front(); // remove the oldest if full + } + + let entry = EventEntry { + stream_id, + time_stamp, + message, + }; + + session_map.push_back(entry); + + event_id + } + + /// Removes all events associated with a given stream ID within a specific session. + /// + /// Removes events matching `stream_id` from the specified `session_id`’s event queue. + /// If the session’s queue becomes empty, it is removed from the store. + /// Idempotent if `session_id` or `stream_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to target. + /// - `stream_id`: The stream identifier to remove. + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId) { + let mut storage_map = self.storage_map.write().await; + + // Check if session exists + if let Some(events) = storage_map.get_mut(&session_id) { + // Remove events with the given stream_id + events.retain(|event| event.stream_id != stream_id); + // Remove session if empty + if events.is_empty() { + storage_map.remove(&session_id); + } + } + // No action if session_id doesn’t exist (idempotent) + } + + /// Removes all events associated with a given session ID. + /// + /// Removes the entire session from the store. Idempotent if `session_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to remove. + async fn remove_by_session_id(&self, session_id: SessionId) { + let mut storage_map = self.storage_map.write().await; + storage_map.remove(&session_id); + } + + /// Retrieves events after a given `event_id` for a specific session and stream. + /// + /// Parses `last_event_id` to extract `session_id`, `stream_id`, and `time_stamp`. + /// Returns events after the matching event in the session’s stream, sorted by timestamp + /// in ascending order (earliest to latest). Returns `None` if the `event_id` is invalid, + /// the session doesn’t exist, or the timestamp is non-numeric. + /// + /// # Arguments + /// - `last_event_id`: The event ID (format: `session-.-stream-.-timestamp`) to start after. + /// + /// # Returns + /// An `Option` containing `EventStoreMessages` with the session ID, stream ID, and sorted messages, + /// or `None` if no events are found or the input is invalid. + async fn events_after(&self, last_event_id: EventId) -> Option { + let Some((session_id, stream_id, time_stamp)) = self.parse_event_id(&last_event_id) else { + tracing::warn!("error parsing last event id: '{last_event_id}'"); + return None; + }; + + let storage_map = self.storage_map.read().await; + let Some(events) = storage_map.get(session_id) else { + tracing::warn!("could not find the session_id in the store : '{session_id}'"); + return None; + }; + + let Ok(time_stamp) = time_stamp.parse::() else { + tracing::warn!("could not parse the timestamp: '{time_stamp}'"); + return None; + }; + + let events = match events + .iter() + .position(|e| e.stream_id == stream_id && e.time_stamp == time_stamp) + { + Some(index) if index + 1 < events.len() => { + // Collect subsequent events that match the stream_id + let mut subsequent: Vec<_> = events + .range(index + 1..) + .filter(|e| e.stream_id == stream_id) + .cloned() + .collect(); + + subsequent.sort_by(|a, b| a.time_stamp.cmp(&b.time_stamp)); + subsequent.iter().map(|e| e.message.clone()).collect() + } + _ => vec![], + }; + + tracing::trace!("{} messages after '{last_event_id}'", events.len()); + + Some(EventStoreMessages { + session_id: session_id.to_string(), + stream_id: stream_id.to_string(), + messages: events, + }) + } + + async fn clear(&self) { + let mut storage_map = self.storage_map.write().await; + storage_map.clear(); + } +} diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 1634922..d21e5dd 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -1,25 +1,39 @@ // Copyright (c) 2025 mcp-rust-stack // Licensed under the MIT License. See LICENSE file for details. // Modifications to this file must be documented with a description of the changes made. + #[cfg(feature = "sse")] mod client_sse; +#[cfg(feature = "streamable-http")] +mod client_streamable_http; +mod constants; pub mod error; +pub mod event_store; mod mcp_stream; mod message_dispatcher; mod schema; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod sse; +#[cfg(feature = "stdio")] mod stdio; mod transport; mod utils; #[cfg(feature = "sse")] pub use client_sse::*; +#[cfg(feature = "streamable-http")] +pub use client_streamable_http::*; +pub use constants::*; pub use message_dispatcher::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub use sse::*; +#[cfg(feature = "stdio")] pub use stdio::*; pub use transport::*; // Type alias for session identifier, represented as a String pub type SessionId = String; +// Type alias for stream identifier (that will be used at the transport scope), represented as a String +pub type StreamId = String; +// Type alias for event (MCP message) identifier, represented as a String +pub type EventId = String; diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 2d2a377..0b10918 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -5,12 +5,7 @@ use crate::{ utils::CancellationToken, IoStream, }; -use std::{ - collections::HashMap, - pin::Pin, - sync::{atomic::AtomicI64, Arc}, - time::Duration, -}; +use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; use tokio::task::JoinHandle; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -57,10 +52,42 @@ impl MCPStream { // rpc message stream that receives incoming messages - let sender = MessageDispatcher::new( + let sender = MessageDispatcher::new(pending_requests, writable, request_timeout); + + (stream, sender, error_io) + } + + pub fn create_with_ack( + readable: Pin>, + writable: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + error_io: IoStream, + pending_requests: Arc>>>, + request_timeout: Duration, + cancellation_token: CancellationToken, + ) -> ( + tokio_stream::wrappers::ReceiverStream, + MessageDispatcher, + IoStream, + ) + where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + X: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + { + let (tx, rx) = tokio::sync::mpsc::channel::(CHANNEL_CAPACITY); + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + // Clone cancellation_token for reader + let reader_token = cancellation_token.clone(); + + #[allow(clippy::let_underscore_future)] + let _ = Self::spawn_reader(readable, tx, reader_token); + + let sender = MessageDispatcher::new_with_acknowledgement( pending_requests, writable, - Arc::new(AtomicI64::new(0)), request_timeout, ); diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 22d0b58..cd9727c 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -1,25 +1,29 @@ -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, ServerMessages, +use crate::error::{TransportError, TransportResult}; +use crate::schema::{RequestId, RpcError}; +use crate::utils::{await_timeout, current_timestamp}; +use crate::McpDispatch; +use crate::{ + event_store::EventStore, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, + ServerMessages, + }, + JsonrpcError, }, - JsonrpcError, + SessionId, StreamId, }; -use crate::schema::{RequestId, RpcError}; use async_trait::async_trait; use futures::future::join_all; - use std::collections::HashMap; use std::pin::Pin; -use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::sync::oneshot::{self}; use tokio::sync::Mutex; -use crate::error::{TransportError, TransportResult}; -use crate::utils::await_timeout; -use crate::McpDispatch; +pub const ID_SEPARATOR: u8 = b'|'; /// Provides a dispatcher for sending MCP messages and handling responses. /// @@ -30,9 +34,18 @@ use crate::McpDispatch; /// a configurable timeout mechanism for asynchronous responses. pub struct MessageDispatcher { pending_requests: Arc>>>, - writable_std: Mutex>>, - message_id_counter: Arc, + writable_std: Option>>>, + writable_tx: Option< + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + >, request_timeout: Duration, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } impl MessageDispatcher { @@ -49,51 +62,49 @@ impl MessageDispatcher { pub fn new( pending_requests: Arc>>>, writable_std: Mutex>>, - message_id_counter: Arc, request_timeout: Duration, ) -> Self { Self { pending_requests, - writable_std, - message_id_counter, + writable_std: Some(writable_std), + writable_tx: None, request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } - /// Determines the request ID for an outgoing MCP message. - /// - /// For requests, generates a new ID using the internal counter. For responses or errors, - /// uses the provided `request_id`. Notifications receive no ID. - /// - /// # Arguments - /// * `message` - The MCP message to evaluate. - /// * `request_id` - An optional existing request ID (required for responses/errors). - /// - /// # Returns - /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - pub fn request_id_for_message( - &self, - message: &impl McpMessage, - request_id: Option, - ) -> Option { - // we need to produce next request_id for requests - if message.is_request() { - // request_id should be None for requests - assert!(request_id.is_none()); - Some(self.next_request_id()) - } else if !message.is_notification() { - // `request_id` must not be `None` for errors, notifications and responses - assert!(request_id.is_some()); - request_id - } else { - None + pub fn new_with_acknowledgement( + pending_requests: Arc>>>, + writable_tx: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + request_timeout: Duration, + ) -> Self { + Self { + pending_requests, + writable_tx: Some(writable_tx), + writable_std: None, + request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } - pub fn next_request_id(&self) -> RequestId { - RequestId::Integer( - self.message_id_counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed), - ) + + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. + pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); } async fn store_pending_request( @@ -158,14 +169,14 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; if let Some(rx) = rx_response { // Wait for the response with timeout match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { Ok(response) => Ok(Some(ServerMessages::Single(response))), Err(error) => match error { - TransportError::OneshotRecvError(_) => { + TransportError::ChannelClosed(_) => { Err(schema_utils::SdkError::connection_closed().into()) } _ => Err(error), @@ -187,19 +198,20 @@ impl McpDispatch }) .unzip(); + // Ensure all request IDs are stored before sending the request + let tasks = join_all(pending_tasks).await; + // send the batch messages to the server let message_payload = serde_json::to_string(&client_messages).map_err(|_| { crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; // no request in the batch, no need to wait for the result - if pending_tasks.is_empty() { + if request_ids.is_empty() { return Ok(None); } - let tasks = join_all(pending_tasks).await; - let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| { rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout))) }); @@ -249,12 +261,25 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + async fn write_str(&self, payload: &str, _skip_store: bool) -> TransportResult<()> { + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(format!("{err}")))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } @@ -292,7 +317,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; if let Some(rx) = rx_response { match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { @@ -320,7 +345,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; // no request in the batch, no need to wait for the result if pending_tasks.is_empty() { @@ -378,11 +403,49 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { + let mut event_id = None; + + if !skip_store && !payload.trim().is_empty() { + if let (Some(session_id), Some(stream_id), Some(event_store)) = ( + self.session_id.as_ref(), + self.stream_id.as_ref(), + self.event_store.as_ref(), + ) { + event_id = Some( + event_store + .store_event( + session_id.clone(), + stream_id.clone(), + current_timestamp(), + payload.to_owned(), + ) + .await, + ) + }; + } + + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + if let Some(id) = event_id { + writable_std.write_all(id.as_bytes()).await?; + writable_std.write_all(&[ID_SEPARATOR]).await?; // separate id from message + } + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(err.to_string()))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 50dbb32..89ca67f 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -1,3 +1,4 @@ +use crate::event_store::EventStore; use crate::schema::schema_utils::{ ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, }; @@ -19,7 +20,7 @@ use crate::mcp_stream::MCPStream; use crate::message_dispatcher::MessageDispatcher; use crate::transport::Transport; use crate::utils::{endpoint_with_session_id, CancellationTokenSource}; -use crate::{IoStream, McpDispatch, SessionId, TransportDispatcher, TransportOptions}; +use crate::{IoStream, McpDispatch, SessionId, StreamId, TransportDispatcher, TransportOptions}; pub struct SseTransport where @@ -33,6 +34,10 @@ where message_sender: Arc>>>, error_stream: tokio::sync::RwLock>, pending_requests: Arc>>>, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } /// Server-Sent Events (SSE) transport implementation @@ -67,6 +72,9 @@ where message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: None, + stream_id: None, + event_store: None, }) } @@ -86,6 +94,19 @@ where let mut lock = self.error_stream.write().await; *lock = Some(IoStream::Writable(error_stream)); } + + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. + pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); + } } #[async_trait] @@ -123,10 +144,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -156,12 +177,12 @@ impl Transport( + let (stream, mut sender, error_stream) = MCPStream::create::( Box::pin(read_rx), Mutex::new(Box::pin(write_tx)), IoStream::Writable(Box::pin(tokio::io::stderr())), @@ -170,6 +191,18 @@ impl Transport {} - Err(TransportError::StdioError(error)) => { + Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { let _ = disconnect_tx.send(()); break; diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 06931d2..7678c65 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -1,5 +1,6 @@ use crate::schema::schema_utils::{ - ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, + ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage, + ServerMessages, }; use crate::schema::RequestId; use async_trait::async_trait; @@ -193,30 +194,29 @@ where #[cfg(unix)] command.process_group(0); - let mut process = command.spawn().map_err(TransportError::StdioError)?; + let mut process = command.spawn().map_err(TransportError::Io)?; let stdin = process .stdin .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?; let stdout = process .stdout .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?; let stderr = process .stderr .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?; - let pending_requests_clone1 = self.pending_requests.clone(); - let pending_requests_clone2 = self.pending_requests.clone(); + let pending_requests_clone = self.pending_requests.clone(); tokio::spawn(async move { let _ = process.wait().await; // clean up pending requests to cancel waiting tasks - let mut pending_requests = pending_requests_clone1.lock().await; + let mut pending_requests = pending_requests_clone.lock().await; pending_requests.clear(); }); @@ -224,7 +224,7 @@ where Box::pin(stdout), Mutex::new(Box::pin(stdin)), IoStream::Readable(Box::pin(stderr)), - pending_requests_clone2, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); @@ -237,13 +237,11 @@ where Ok(stream) } else { - let pending_requests: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); let (stream, sender, error_stream) = MCPStream::create( Box::pin(tokio::io::stdin()), Mutex::new(Box::pin(tokio::io::stdout())), IoStream::Writable(Box::pin(tokio::io::stderr())), - pending_requests, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); @@ -277,7 +275,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(), )) } @@ -287,7 +285,7 @@ where _interval: Duration, _disconnect_tx: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for StdioTransport".to_string(), )) } @@ -350,10 +348,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -367,3 +365,55 @@ impl > for StdioTransport { } + +#[async_trait] +impl McpDispatch + for StdioTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload, skip_store).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for StdioTransport +{ +} diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index 3d17ebd..a9e7190 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -1,15 +1,12 @@ -use std::{pin::Pin, sync::Arc, time::Duration}; - -use crate::schema::RequestId; +use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; +use crate::{schema::RequestId, SessionId}; use async_trait::async_trait; - +use std::{pin::Pin, sync::Arc, time::Duration}; use tokio::{ sync::oneshot::{self, Sender}, task::JoinHandle, }; -use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; - /// Default Timeout in milliseconds const DEFAULT_TIMEOUT_MSEC: u64 = 60_000; @@ -85,7 +82,7 @@ where /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()>; + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()>; } /// A trait representing the transport layer for the MCP (Message Communication Protocol). @@ -125,6 +122,9 @@ where interval: Duration, disconnect_tx: oneshot::Sender<()>, ) -> TransportResult>; + async fn session_id(&self) -> Option { + None + } } /// A composite trait that combines both transport and dispatch capabilities for the MCP protocol. @@ -160,3 +160,26 @@ where OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { } + +// pub trait IntoClientTransport { +// type TransportType: Transport< +// ServerMessages, +// MessageFromClient, +// ServerMessage, +// ClientMessages, +// ClientMessage, +// >; + +// fn into_transport(self, session_id: Option) -> TransportResult; +// } + +// impl IntoClientTransport for T +// where +// T: Transport, +// { +// type TransportType = T; + +// fn into_transport(self, _: Option) -> TransportResult { +// Ok(self) +// } +// } diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 218d517..034f062 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -1,38 +1,48 @@ mod cancellation_token; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod http_utils; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod readable_channel; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +mod sse_parser; #[cfg(feature = "sse")] mod sse_stream; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +mod streamable_http_stream; +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; pub(crate) use cancellation_token::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use readable_channel::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +pub(crate) use sse_parser::*; #[cfg(feature = "sse")] pub(crate) use sse_stream::*; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +pub(crate) use streamable_http_stream::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use writable_channel::*; +mod time_utils; +pub use time_utils::*; use crate::schema::schema_utils::SdkError; use tokio::time::{timeout, Duration}; use crate::error::{TransportError, TransportResult}; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] use crate::SessionId; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where F: std::future::Future>, // The operation returns a Result - E: Into, // The error type must be convertible to TransportError + E: Into, { match timeout(timeout_duration, operation).await { - Ok(result) => result.map_err(|err| err.into()), // Convert the error type into TransportError + Ok(result) => result.map_err(|err| err.into()), Err(_) => Err(SdkError::request_timeout(timeout_duration.as_millis()).into()), // Timeout error } } @@ -46,7 +56,7 @@ where /// # Returns /// A String containing the endpoint with the session ID added as a query parameter /// -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { // Handle empty endpoint let base = if endpoint.is_empty() { "/" } else { endpoint }; diff --git a/crates/rust-mcp-transport/src/utils/http_utils.rs b/crates/rust-mcp-transport/src/utils/http_utils.rs index 701dcb0..84b62dd 100644 --- a/crates/rust-mcp-transport/src/utils/http_utils.rs +++ b/crates/rust-mcp-transport/src/utils/http_utils.rs @@ -1,7 +1,35 @@ use crate::error::{TransportError, TransportResult}; +use crate::{SessionId, MCP_SESSION_ID_HEADER}; -use reqwest::header::{HeaderMap, CONTENT_TYPE}; -use reqwest::Client; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE}; +use reqwest::{Client, Response}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponseType { + EventStream, + Json, +} + +/// Determines the response type based on the `Content-Type` header. +pub async fn validate_response_type(response: &Response) -> TransportResult { + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(content_type) => { + let content_type_str = content_type.to_str().map_err(|_| { + TransportError::UnexpectedContentType("".to_string()) + })?; + + // Normalize to lowercase for case-insensitive comparison + let content_type_normalized = content_type_str.to_ascii_lowercase(); + + match content_type_normalized.as_str() { + "text/event-stream" => Ok(ResponseType::EventStream), + "application/json" => Ok(ResponseType::Json), + other => Err(TransportError::UnexpectedContentType(other.to_string())), + } + } + None => Err(TransportError::UnexpectedContentType("".to_string())), + } +} /// Sends an HTTP POST request with the given body and headers /// @@ -17,21 +45,96 @@ pub async fn http_post( client: &Client, post_url: &str, body: String, - headers: &Option, -) -> TransportResult<()> { + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { let mut request = client .post(post_url) .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") .body(body); if let Some(map) = headers { request = request.headers(map.clone()); } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + let response = request.send().await?; if !response.status().is_success() { - return Err(TransportError::HttpError(response.status().as_u16())); + return Err(TransportError::Http(response.status())); } - Ok(()) + Ok(response) +} + +pub async fn http_get( + client: &Client, + url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .get(url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + return Err(TransportError::Http(response.status())); + } + Ok(response) +} + +pub async fn http_delete( + client: &Client, + post_url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .delete(post_url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status_code = response.status(); + return Err(TransportError::Http(status_code)); + } + Ok(response) +} + +#[allow(unused)] +pub fn get_header_value(response: &Response, header_name: HeaderName) -> Option { + let content_type = response.headers().get(header_name)?.to_str().ok()?; + Some(content_type.to_string()) } pub fn extract_origin(url: &str) -> Option { @@ -88,7 +191,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -113,11 +216,11 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is an HttpError with status 400 match result { - Err(TransportError::HttpError(status)) => assert_eq!(status, 400), + Err(TransportError::Http(status)) => assert_eq!(status, 400), _ => panic!("Expected HttpError with status 400"), } } @@ -142,7 +245,7 @@ mod tests { let headers = Some(create_test_headers()); // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -157,7 +260,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, url, body, &headers).await; + let result = http_post(&client, url, body, None, headers.as_ref()).await; // Assert the result is an error (likely a connection error) assert!(result.is_err()); diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs new file mode 100644 index 0000000..5933726 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -0,0 +1,320 @@ +use core::fmt; +use std::collections::HashMap; + +use bytes::{Bytes, BytesMut}; +const BUFFER_CAPACITY: usize = 1024; + +/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. +/// +/// Contains the event type, data payload, and optional event ID. +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl std::fmt::Display for SseEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(id) = &self.id { + writeln!(f, "id: {id}")?; + } + + if let Some(event) = &self.event { + writeln!(f, "event: {event}")?; + } + + if let Some(data) = &self.data { + match std::str::from_utf8(data) { + Ok(text) => { + for line in text.lines() { + writeln!(f, "data: {line}")?; + } + } + Err(_) => { + writeln!(f, "data: [binary data]")?; + } + } + } + + writeln!(f)?; // Trailing newline for SSE message end + Ok(()) + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self + .data + .as_ref() + .map(|b| String::from_utf8_lossy(b).to_string()); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +/// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. +/// This Parser is specifically designed for MCP messages and with no multi-line data support +/// +/// This struct maintains a buffer to accumulate incoming data and parses it into SSE events +/// based on the SSE protocol. It handles fields like `event`, `data`, and `id` as defined +/// in the SSE specification. +#[derive(Debug)] +pub struct SseParser { + pub buffer: BytesMut, +} + +impl SseParser { + /// Creates a new `SseParser` with an empty buffer pre-allocated to a default capacity. + /// + /// The buffer is initialized with a capacity of `BUFFER_CAPACITY` to + /// optimize for typical SSE message sizes. + /// + /// # Returns + /// A new `SseParser` instance with an empty buffer. + pub fn new() -> Self { + Self { + buffer: BytesMut::with_capacity(BUFFER_CAPACITY), + } + } + + /// Processes a new chunk of bytes and parses it into a vector of `SseEvent`s. + /// + /// This method appends the incoming `bytes` to the internal buffer, splits it into + /// complete lines (delimited by `\n`), and parses each line according to the SSE + /// protocol. It supports `event`, `id`, and `data` fields, as well as comments + /// (lines starting with `:`). Empty lines are skipped, and incomplete lines remain + /// in the buffer for future processing. + /// + /// # Parameters + /// - `bytes`: The incoming chunk of bytes to parse. + /// + /// # Returns + /// A vector of `SseEvent`s parsed from the complete lines in the buffer. If no + /// complete events are found, an empty vector is returned. + pub fn process_new_chunk(&mut self, bytes: Bytes) -> Vec { + self.buffer.extend_from_slice(&bytes); + + // Collect complete lines (ending in \n)β€”keep ALL lines, including empty ones for \n\n detection + let mut lines = Vec::new(); + while let Some(pos) = self.buffer.iter().position(|&b| b == b'\n') { + let line = self.buffer.split_to(pos + 1).freeze(); + lines.push(line); + } + + let mut events = Vec::new(); + let mut current_message_lines: Vec = Vec::new(); + + for line in lines { + current_message_lines.push(line); + + // Check if we've hit a double newline (end of message) + if current_message_lines.len() >= 2 + && current_message_lines + .last() + .is_some_and(|b| b.as_ref() == b"\n") + { + // Process the complete message (exclude the last empty lines for parsing) + let message_lines: Vec<_> = current_message_lines + .drain(..current_message_lines.len() - 1) + .filter(|l| l.as_ref() != b"\n") // Filter internal empties + .collect(); + + if let Some(event) = self.parse_sse_message(&message_lines) { + events.push(event); + } + } + } + + // Put back any incomplete message + if !current_message_lines.is_empty() { + self.buffer.clear(); + for line in current_message_lines { + self.buffer.extend_from_slice(&line); + } + } + + events + } + + fn parse_sse_message(&self, lines: &[Bytes]) -> Option { + let mut fields: HashMap = HashMap::new(); + let mut data_parts: Vec = Vec::new(); + + for line_bytes in lines { + let line_str = String::from_utf8_lossy(line_bytes); + + // Skip comments and empty lines + if line_str.is_empty() || line_str.starts_with(':') { + continue; + } + + let (key, value) = if let Some(value) = line_str.strip_prefix("data: ") { + ("data", value.trim_start().to_string()) + } else if let Some(value) = line_str.strip_prefix("event: ") { + ("event", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("id: ") { + ("id", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("retry: ") { + ("retry", value.trim().to_string()) + } else { + // Invalid line; skip + continue; + }; + + if key == "data" { + if !value.is_empty() { + data_parts.push(value); + } + } else { + fields.insert(key.to_string(), value); + } + } + + // Build data (concat multi-line data with \n) , should not occur in MCP tho + let data = if data_parts.is_empty() { + None + } else { + let full_data = data_parts.join("\n"); + Some(Bytes::copy_from_slice(full_data.as_bytes())) // Use copy_from_slice for efficiency + }; + + // Skip invalid message with no data + let data = data?; + + // Get event (default to None) + let event = fields.get("event").cloned(); + let id = fields.get("id").cloned(); + + Some(SseEvent { + event, + data: Some(data), + id, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[test] + fn test_single_data_event() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + assert!(events[0].event.is_none()); + assert!(events[0].id.is_none()); + } + + #[test] + fn test_event_with_id_and_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("event: message\nid: 123\ndata: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_event_chunks_in_different_orders() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\nevent: message\nid: 123\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_comment_line_ignored() { + let mut parser = SseParser::new(); + let input = Bytes::from(": this is a comment\n\n"); + let events = parser.process_new_chunk(input); + assert_eq!(events.len(), 0); + } + + #[test] + fn test_event_with_empty_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("data:\n\n"); + let events = parser.process_new_chunk(input); + // Your parser skips data lines with empty content + assert_eq!(events.len(), 0); + } + + #[test] + fn test_partial_chunks() { + let mut parser = SseParser::new(); + + let part1 = Bytes::from("data: hello"); + let part2 = Bytes::from(" world\n\n"); + + let events1 = parser.process_new_chunk(part1); + assert_eq!(events1.len(), 0); // incomplete + + let events2 = parser.process_new_chunk(part2); + assert_eq!(events2.len(), 1); + assert_eq!( + events2[0].data.as_deref(), + Some(Bytes::from("hello world\n").as_ref()) + ); + } + + #[test] + fn test_malformed_lines() { + let mut parser = SseParser::new(); + let input = Bytes::from("something invalid\ndata: ok\n\n"); + + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("ok\n").as_ref()) + ); + } + + #[test] + fn test_multiple_events_in_one_chunk() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: first\n\ndata: second\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 2); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("first\n").as_ref()) + ); + assert_eq!( + events[1].data.as_deref(), + Some(Bytes::from("second\n").as_ref()) + ); + } +} diff --git a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs new file mode 100644 index 0000000..3362c71 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs @@ -0,0 +1,374 @@ +use super::CancellationToken; +use crate::error::{TransportError, TransportResult}; +use crate::utils::SseParser; +use crate::utils::{http_get, validate_response_type, ResponseType}; +use crate::{utils::http_post, MCP_SESSION_ID_HEADER}; +use crate::{EventId, MCP_LAST_EVENT_ID_HEADER}; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::{Client, Response, StatusCode}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, RwLock}; +use tokio::time; +use tokio_stream::StreamExt; + +//-----------------------------------------------------------------------------------// +pub(crate) struct StreamableHttpStream { + /// HTTP client for making SSE requests + pub client: Client, + /// URL of the SSE endpoint + pub mcp_url: String, + /// Maximum number of retry attempts for failed connections + pub max_retries: usize, + /// Delay between retry attempts + pub retry_delay: Duration, + /// Sender for transmitting received data to the readable channel + pub read_tx: mpsc::Sender, + /// Session id will be received from the server in the http + pub session_id: Arc>>, +} + +impl StreamableHttpStream { + pub(crate) async fn run( + &mut self, + payload: String, + cancellation_token: &CancellationToken, + custom_headers: &Option, + ) -> TransportResult<()> { + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let session_id = self.session_id.read().await.clone(); + + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!( + "StreamableHttp cancelled before connection attempt {}", + payload + ); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + //TODO: simplify + let response = match http_post( + &self.client, + &self.mcp_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await + { + Ok(response) => { + // if session_id_clone.read().await.is_none() { + let session_id = response + .headers() + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let mut guard = self.session_id.write().await; + *guard = session_id; + response + } + + Err(error) => { + tracing::error!("Failed to connect to MCP endpoint: {error}"); + return Err(error); + } + }; + + // return if status code != 200 and no result is expected + if response.status() != StatusCode::OK { + return Ok(()); + } + + let response_type = validate_response_type(&response).await?; + + // Handle non-streaming JSON response + if response_type == ResponseType::Json { + return match response.bytes().await { + Ok(bytes) => { + // Send the message + self.read_tx.send(bytes).await.map_err(|_| { + tracing::error!("Readable stream closed, shutting down MCP task"); + TransportError::SendFailure( + "Failed to send message: channel closed or full".to_string(), + ) + })?; + + // Send the newline + self.read_tx + .send(Bytes::from_static(b"\n")) + .await + .map_err(|_| { + tracing::error!( + "Failed to send newline, channel may be closed or full" + ); + TransportError::SendFailure( + "Failed to send newline: channel closed or full".to_string(), + ) + })?; + + Ok(()) + } + Err(error) => Err(error.into()), + }; + } + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unlike SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + // break after receiving the message(s) + return Ok(()); + } + } + Err(error) => { + tracing::error!("Error reading stream: {error}"); + return Err(error.into()); + } + } + } + } + + pub(crate) async fn make_standalone_stream_connection( + &self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + last_event_id: Option, + ) -> TransportResult { + let mut retry_count = 0; + let session_id = self.session_id.read().await.clone(); + + let headers = if let Some(event_id) = last_event_id.as_ref() { + let mut headers = HeaderMap::new(); + if let Some(custom) = custom_headers { + headers.extend(custom.iter().map(|(k, v)| (k.clone(), v.clone()))); + } + if let Ok(event_id_value) = HeaderValue::from_str(event_id) { + headers.insert(MCP_LAST_EVENT_ID_HEADER, event_id_value); + } + &Some(headers) + } else { + custom_headers + }; + + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + match http_get( + &self.client, + &self.mcp_url, + session_id.as_ref(), + headers.as_ref(), + ) + .await + { + Ok(response) => { + let is_event_stream = validate_response_type(&response) + .await + .is_ok_and(|response_type| response_type == ResponseType::EventStream); + + if !is_event_stream { + let message = + "SSE stream response returned an unexpected Content-Type.".to_string(); + tracing::warn!("{message}"); + return Err(TransportError::FailedToOpenSSEStream(message)); + } + + return Ok(response); + } + + Err(error) => { + match error { + crate::error::TransportError::HttpConnection(_) => { + // A reqwest::Error happened, we do not return ans instead retry the operation + } + crate::error::TransportError::Http(status_code) => match status_code { + StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => { + return Err(crate::error::TransportError::FailedToOpenSSEStream( + format!("Not supported (code: {status_code})"), + )); + } + other => { + tracing::warn!( + "Failed to open SSE stream: {error} (code: {other})" + ); + } + }, + error => { + return Err(error); // return the error where the retry wont help + } + } + + if retry_count >= self.max_retries { + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error); + } + retry_count += 1; + time::sleep(self.retry_delay).await; + continue; + } + }; + } + } + + pub(crate) async fn run_standalone( + &mut self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + response: Response, + ) -> TransportResult<()> { + let mut retry_count = 0; + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let mut response = Some(response); + + // Main loop for reconnection attempts + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::debug!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + // use initially passed response, otherwise try to make a new sse connection + let response = match response.take() { + Some(response) => response, + None => { + tracing::debug!( + "Reconnecting to SSE stream... (try {} of {})", + retry_count, + self.max_retries + ); + self.make_standalone_stream_connection( + cancellation_token, + custom_headers, + _last_event_id.clone(), + ) + .await? + } + }; + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unlike SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + } + retry_count = 0; // Reset retry count on successful chunk + } + Err(error) => { + if retry_count >= self.max_retries { + tracing::error!("Error reading stream: {error}"); + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error.into()); + } + + tracing::debug!( + "The standalone SSE stream encountered an error: '{}'", + error + ); + retry_count += 1; + time::sleep(self.retry_delay).await; + break; // Break inner loop to reconnect + } + } + } + } + } +} diff --git a/crates/rust-mcp-transport/src/utils/time_utils.rs b/crates/rust-mcp-transport/src/utils/time_utils.rs new file mode 100644 index 0000000..25c4f5d --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/time_utils.rs @@ -0,0 +1,8 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn current_timestamp() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Invalid time") + .as_nanos() +} diff --git a/crates/rust-mcp-transport/tests/check_imports.rs b/crates/rust-mcp-transport/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-transport/tests/check_imports.rs +++ b/crates/rust-mcp-transport/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/development.md b/development.md index e3673cc..e17dd17 100644 --- a/development.md +++ b/development.md @@ -33,14 +33,14 @@ Build and run instructions are available in their respective README.md files. You can run examples by passing the example project name to Cargo using the `-p` argument, like this: ```sh -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` -You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server`: +You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server-stdio`: ```sh -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` ## Code Formatting diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 358b1b4..418fd66 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -40,7 +40,7 @@ edition = "2024" [dependencies] async-trait = "0.1" -rust-mcp-sdk = "0.4" +rust-mcp-sdk = "0.7" serde = "1.0" serde_json = "1.0" tokio = "1.4" @@ -72,11 +72,10 @@ Create a new module in the project called `tools.rs` and include the definitions //src/tools.rs use rust_mcp_sdk::schema::{CallToolResult, TextContent, schema_utils::CallToolError}; use rust_mcp_sdk::{ - macros::{mcp_tool, JsonSchema}, + macros::{JsonSchema, mcp_tool}, tool_box, }; - //****************// // SayHelloTool // //****************// @@ -93,7 +92,9 @@ pub struct SayHelloTool { impl SayHelloTool { pub fn call_tool(&self) -> Result { let hello_message = format!("Hello, {}!", self.name); - Ok(CallToolResult::text_content( vec![TextContent::from(hello_message)] )) + Ok(CallToolResult::text_content(vec![TextContent::from( + hello_message, + )])) } } @@ -112,7 +113,9 @@ pub struct SayGoodbyeTool { impl SayGoodbyeTool { pub fn call_tool(&self) -> Result { let hello_message = format!("Goodbye, {}!", self.name); - Ok(CallToolResult::text_content( vec![TextContent::from(hello_message)] )) + Ok(CallToolResult::text_content(vec![TextContent::from( + hello_message, + )])) } } @@ -142,12 +145,14 @@ Here is the code for `handler.rs` : ```rs // src/handler.rs +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ - schema_utils::CallToolError, CallToolRequest, CallToolResult, RpcError, - ListToolsRequest, ListToolsResult, + CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, RpcError, + schema_utils::CallToolError, }; -use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use rust_mcp_sdk::{McpServer, mcp_server::ServerHandler}; use crate::tools::GreetingTools; @@ -160,7 +165,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, _request: ListToolsRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -173,7 +178,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = @@ -207,14 +212,11 @@ mod handler; mod tools; use handler::MyServerHandler; use rust_mcp_sdk::schema::{ - Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, - LATEST_PROTOCOL_VERSION, + Implementation, InitializeResult, LATEST_PROTOCOL_VERSION, ServerCapabilities, + ServerCapabilitiesTools, }; - use rust_mcp_sdk::{ - error::SdkResult, - mcp_server::{server_runtime, ServerRuntime}, - McpServer, StdioTransport, TransportOptions, + McpServer, StdioTransport, TransportOptions, error::SdkResult, mcp_server::server_runtime, }; #[tokio::main] @@ -244,7 +246,7 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; //create the MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); + let server = server_runtime::create_server(server_details, transport, handler); // Start the server server.start().await diff --git a/examples/hello-world-mcp-server-core/.gitignore b/examples/hello-world-mcp-server-stdio-core/.gitignore similarity index 100% rename from examples/hello-world-mcp-server-core/.gitignore rename to examples/hello-world-mcp-server-stdio-core/.gitignore diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-stdio-core/Cargo.toml similarity index 83% rename from examples/hello-world-mcp-server-core/Cargo.toml rename to examples/hello-world-mcp-server-stdio-core/Cargo.toml index a38a0b9..f37d4c4 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server-core" -version = "0.1.16" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.20" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server-core/README.md b/examples/hello-world-mcp-server-stdio-core/README.md similarity index 81% rename from examples/hello-world-mcp-server-core/README.md rename to examples/hello-world-mcp-server-stdio-core/README.md index af9d703..cf57884 100644 --- a/examples/hello-world-mcp-server-core/README.md +++ b/examples/hello-world-mcp-server-stdio-core/README.md @@ -23,14 +23,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server-core --release +cargo build -p hello-world-mcp-server-stdio-core --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-core` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio-core` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-core +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio-core ``` ``` @@ -41,4 +41,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-stdio-core/src/handler.rs similarity index 96% rename from examples/hello-world-mcp-server-core/src/handler.rs rename to examples/hello-world-mcp-server-stdio-core/src/handler.rs index fcde15e..acf55ea 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -90,7 +92,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -98,8 +100,8 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, - _: &dyn McpServer, + error: &RpcError, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-mcp-server-core/src/main.rs b/examples/hello-world-mcp-server-stdio-core/src/main.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/main.rs rename to examples/hello-world-mcp-server-stdio-core/src/main.rs diff --git a/examples/hello-world-mcp-server-core/src/tools.rs b/examples/hello-world-mcp-server-stdio-core/src/tools.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/tools.rs rename to examples/hello-world-mcp-server-stdio-core/src/tools.rs diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server-stdio/Cargo.toml similarity index 85% rename from examples/hello-world-mcp-server/Cargo.toml rename to examples/hello-world-mcp-server-stdio/Cargo.toml index 7fc7d0f..1947dce 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server" -version = "0.1.25" +name = "hello-world-mcp-server-stdio" +version = "0.1.29" edition = "2021" publish = false license = "MIT" @@ -10,8 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", - "hyper-server", - "ssl", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server/README.md b/examples/hello-world-mcp-server-stdio/README.md similarity index 84% rename from examples/hello-world-mcp-server/README.md rename to examples/hello-world-mcp-server-stdio/README.md index 33a62af..9e0bdda 100644 --- a/examples/hello-world-mcp-server/README.md +++ b/examples/hello-world-mcp-server-stdio/README.md @@ -22,14 +22,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio ``` ``` @@ -40,4 +40,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server/src/handler.rs b/examples/hello-world-mcp-server-stdio/src/handler.rs similarity index 94% rename from examples/hello-world-mcp-server/src/handler.rs rename to examples/hello-world-mcp-server-stdio/src/handler.rs index d9741a0..47925a0 100644 --- a/examples/hello-world-mcp-server/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio/src/handler.rs @@ -4,6 +4,7 @@ use rust_mcp_sdk::schema::{ ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use std::sync::Arc; use crate::tools::GreetingTools; @@ -20,7 +21,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +34,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server/src/main.rs b/examples/hello-world-mcp-server-stdio/src/main.rs similarity index 92% rename from examples/hello-world-mcp-server/src/main.rs rename to examples/hello-world-mcp-server-stdio/src/main.rs index 00ca6a7..98ff6f0 100644 --- a/examples/hello-world-mcp-server/src/main.rs +++ b/examples/hello-world-mcp-server-stdio/src/main.rs @@ -1,6 +1,8 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, @@ -40,7 +42,8 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; // STEP 4: create a MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); + let server: Arc = + server_runtime::create_server(server_details, transport, handler); // STEP 5: Start the server if let Err(start_error) = server.start().await { diff --git a/examples/hello-world-mcp-server/src/tools.rs b/examples/hello-world-mcp-server-stdio/src/tools.rs similarity index 73% rename from examples/hello-world-mcp-server/src/tools.rs rename to examples/hello-world-mcp-server-stdio/src/tools.rs index 15d6a8b..f6b1edb 100644 --- a/examples/hello-world-mcp-server/src/tools.rs +++ b/examples/hello-world-mcp-server-stdio/src/tools.rs @@ -1,8 +1,29 @@ use rust_mcp_sdk::schema::{schema_utils::CallToolError, CallToolResult, TextContent}; -use rust_mcp_sdk::{ - macros::{mcp_tool, JsonSchema}, - tool_box, -}; +use rust_mcp_sdk::{macros::mcp_tool, tool_box}; + +use rust_mcp_sdk::macros::JsonSchema; +use rust_mcp_sdk::schema::RpcError; +use std::str::FromStr; + +// Simple enum with FromStr trait implemented +#[derive(JsonSchema, Debug)] +pub enum Colors { + #[json_schema(title = "Green Color")] + Green, + #[json_schema(title = "Red Color")] + Red, +} +impl FromStr for Colors { + type Err = RpcError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "green" => Ok(Colors::Green), + "red" => Ok(Colors::Red), + _ => Err(RpcError::parse_error().with_message("Invalid color".to_string())), + } + } +} //****************// // SayHelloTool // diff --git a/examples/hello-world-server-core-streamable-http/.gitignore b/examples/hello-world-server-streamable-http-core/.gitignore similarity index 100% rename from examples/hello-world-server-core-streamable-http/.gitignore rename to examples/hello-world-server-streamable-http-core/.gitignore diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http-core/Cargo.toml similarity index 84% rename from examples/hello-world-server-core-streamable-http/Cargo.toml rename to examples/hello-world-server-streamable-http-core/Cargo.toml index 84dfd70..85e470a 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-server-core-streamable-http" -version = "0.1.16" +name = "hello-world-server-streamable-http-core" +version = "0.1.20" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-core-streamable-http/README.md b/examples/hello-world-server-streamable-http-core/README.md similarity index 95% rename from examples/hello-world-server-core-streamable-http/README.md rename to examples/hello-world-server-streamable-http-core/README.md index cd37623..49af2c2 100644 --- a/examples/hello-world-server-core-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http-core/README.md @@ -37,7 +37,7 @@ cd rust-mcp-sdk 2. Build and start the server: ```bash -cargo run -p hello-world-server-core-streamable-http --release +cargo run -p hello-world-server-streamable-http-core --release ``` By default, both the Streamable HTTP and SSE endpoints are displayed in the terminal: @@ -65,4 +65,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http-core/src/handler.rs similarity index 96% rename from examples/hello-world-server-core-streamable-http/src/handler.rs rename to examples/hello-world-server-streamable-http-core/src/handler.rs index 53f884c..7941075 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -95,7 +97,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -103,8 +105,8 @@ impl ServerHandlerCore for MyServerHandler { // Process incoming client errors async fn handle_error( &self, - error: RpcError, - _: &dyn McpServer, + error: &RpcError, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-server-core-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs similarity index 91% rename from examples/hello-world-server-core-streamable-http/src/main.rs rename to examples/hello-world-server-streamable-http-core/src/main.rs index 7b41c70..81a6ae5 100644 --- a/examples/hello-world-server-core-streamable-http/src/main.rs +++ b/examples/hello-world-server-streamable-http-core/src/main.rs @@ -1,7 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, @@ -48,6 +51,7 @@ async fn main() -> SdkResult<()> { handler, HyperServerOptions { sse_support: true, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); diff --git a/examples/hello-world-server-core-streamable-http/src/tools.rs b/examples/hello-world-server-streamable-http-core/src/tools.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/tools.rs rename to examples/hello-world-server-streamable-http-core/src/tools.rs diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index 6776b0c..61d080f 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hello-world-server-streamable-http" -version = "0.1.25" +version = "0.1.32" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-streamable-http/README.md b/examples/hello-world-server-streamable-http/README.md index ac56a86..7e3f3b6 100644 --- a/examples/hello-world-server-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http/README.md @@ -66,4 +66,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs index b8ce355..3939d86 100644 --- a/examples/hello-world-server-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http/src/handler.rs @@ -1,12 +1,11 @@ +use crate::tools::GreetingTools; use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; - -use crate::tools::GreetingTools; - +use std::sync::Arc; // Custom Handler to handle MCP Messages pub struct MyServerHandler; @@ -20,7 +19,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +32,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = @@ -45,6 +44,4 @@ impl ServerHandler for MyServerHandler { GreetingTools::SayGoodbyeTool(say_goodbye_tool) => say_goodbye_tool.call_tool(), } } - - async fn on_server_started(&self, runtime: &dyn McpServer) {} } diff --git a/examples/hello-world-server-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http/src/main.rs index cd8c658..3923a6d 100644 --- a/examples/hello-world-server-streamable-http/src/main.rs +++ b/examples/hello-world-server-streamable-http/src/main.rs @@ -1,8 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; use std::time::Duration; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::mcp_server::{hyper_server, HyperServerOptions}; use handler::MyServerHandler; @@ -57,6 +59,7 @@ async fn main() -> SdkResult<()> { HyperServerOptions { host: "127.0.0.1".to_string(), ping_interval: Duration::from_secs(5), + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-sse-core/Cargo.toml similarity index 88% rename from examples/simple-mcp-client-core-sse/Cargo.toml rename to examples/simple-mcp-client-sse-core/Cargo.toml index 3cbd9df..05654fc 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core-sse" -version = "0.1.16" +name = "simple-mcp-client-sse-core" +version = "0.1.20" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "sse", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-core-sse/README.md b/examples/simple-mcp-client-sse-core/README.md similarity index 97% rename from examples/simple-mcp-client-core-sse/README.md rename to examples/simple-mcp-client-sse-core/README.md index e7e10d2..a0852fb 100644 --- a/examples/simple-mcp-client-core-sse/README.md +++ b/examples/simple-mcp-client-sse-core/README.md @@ -32,7 +32,7 @@ npx @modelcontextprotocol/server-everything sse 2. Open a new terminal and run the project with: ```bash -cargo run -p simple-mcp-client-core-sse +cargo run -p simple-mcp-client-sse-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-sse-core/src/handler.rs similarity index 79% rename from examples/simple-mcp-client-core/src/handler.rs rename to examples/simple-mcp-client-sse-core/src/handler.rs index a1a95e4..ab86e9e 100644 --- a/examples/simple-mcp-client-core/src/handler.rs +++ b/examples/simple-mcp-client-sse-core/src/handler.rs @@ -41,16 +41,30 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_notification( &self, - _notification: NotificationFromServer, + notification: NotificationFromServer, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { - Err(RpcError::internal_error() - .with_message("handle_notification() Not implemented".to_string())) + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) } async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) diff --git a/examples/simple-mcp-client-core-sse/src/inquiry_utils.rs b/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core-sse/src/inquiry_utils.rs rename to examples/simple-mcp-client-sse-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core-sse/src/main.rs b/examples/simple-mcp-client-sse-core/src/main.rs similarity index 99% rename from examples/simple-mcp-client-core-sse/src/main.rs rename to examples/simple-mcp-client-sse-core/src/main.rs index 459f9ba..be8279b 100644 --- a/examples/simple-mcp-client-core-sse/src/main.rs +++ b/examples/simple-mcp-client-sse-core/src/main.rs @@ -44,6 +44,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime_core::create_client(client_details, transport, handler); // STEP 5: start the MCP client diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 60dd69c..0720afe 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "simple-mcp-client-sse" -version = "0.1.16" +version = "0.1.23" edition = "2021" publish = false license = "MIT" @@ -9,6 +9,8 @@ license = "MIT" [dependencies] rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", + "sse", + "streamable-http", "macros", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-sse/src/main.rs b/examples/simple-mcp-client-sse/src/main.rs index ce8850a..0a76caa 100644 --- a/examples/simple-mcp-client-sse/src/main.rs +++ b/examples/simple-mcp-client-sse/src/main.rs @@ -15,7 +15,9 @@ use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; -const MCP_SERVER_URL: &str = "http://localhost:3001/sse"; +// Connect to a server started with the following command: +// npx @modelcontextprotocol/server-everything sse +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/sse"; #[tokio::main] async fn main() -> SdkResult<()> { @@ -44,6 +46,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime::create_client(client_details, transport, handler); // STEP 5: start the MCP client @@ -57,6 +60,7 @@ async fn main() -> SdkResult<()> { let utils = InquiryUtils { client: Arc::clone(&client), }; + // Display server information (name and version) utils.print_server_info(); @@ -78,8 +82,11 @@ async fn main() -> SdkResult<()> { // Call add tool, and print the result utils.call_add_tool(100, 25).await?; - // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + // // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } // Send 3 pings to the server, with a 2-second interval between each ping. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-stdio-core/Cargo.toml similarity index 86% rename from examples/simple-mcp-client-core/Cargo.toml rename to examples/simple-mcp-client-stdio-core/Cargo.toml index d0288f9..f7dc568 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core" -version = "0.1.25" +name = "simple-mcp-client-stdio-core" +version = "0.1.29" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-core/README.md b/examples/simple-mcp-client-stdio-core/README.md similarity index 97% rename from examples/simple-mcp-client-core/README.md rename to examples/simple-mcp-client-stdio-core/README.md index 52d8074..f3258aa 100644 --- a/examples/simple-mcp-client-core/README.md +++ b/examples/simple-mcp-client-stdio-core/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client-core +cargo run -p simple-mcp-client-stdio-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-stdio-core/src/handler.rs similarity index 98% rename from examples/simple-mcp-client-core-sse/src/handler.rs rename to examples/simple-mcp-client-stdio-core/src/handler.rs index a1a95e4..bd5e4fe 100644 --- a/examples/simple-mcp-client-core-sse/src/handler.rs +++ b/examples/simple-mcp-client-stdio-core/src/handler.rs @@ -50,7 +50,7 @@ impl ClientHandlerCore for MyClientHandler { async fn handle_error( &self, - _error: RpcError, + _error: &RpcError, _runtime: &dyn McpClient, ) -> std::result::Result<(), RpcError> { Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) diff --git a/examples/simple-mcp-client-core/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core/src/main.rs b/examples/simple-mcp-client-stdio-core/src/main.rs similarity index 100% rename from examples/simple-mcp-client-core/src/main.rs rename to examples/simple-mcp-client-stdio-core/src/main.rs diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client-stdio/Cargo.toml similarity index 87% rename from examples/simple-mcp-client/Cargo.toml rename to examples/simple-mcp-client-stdio/Cargo.toml index cdfa228..7bbd890 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client" -version = "0.1.25" +name = "simple-mcp-client-stdio" +version = "0.1.29" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/simple-mcp-client/README.md b/examples/simple-mcp-client-stdio/README.md similarity index 97% rename from examples/simple-mcp-client/README.md rename to examples/simple-mcp-client-stdio/README.md index c56a933..be17f02 100644 --- a/examples/simple-mcp-client/README.md +++ b/examples/simple-mcp-client-stdio/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client/src/handler.rs b/examples/simple-mcp-client-stdio/src/handler.rs similarity index 100% rename from examples/simple-mcp-client/src/handler.rs rename to examples/simple-mcp-client-stdio/src/handler.rs diff --git a/examples/simple-mcp-client/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client/src/main.rs b/examples/simple-mcp-client-stdio/src/main.rs similarity index 100% rename from examples/simple-mcp-client/src/main.rs rename to examples/simple-mcp-client-stdio/src/main.rs diff --git a/examples/simple-mcp-client-streamable-http-core/Cargo.toml b/examples/simple-mcp-client-streamable-http-core/Cargo.toml new file mode 100644 index 0000000..c8b3464 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http-core" +version = "0.1.1" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "macros", + "streamable-http", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http-core/README.md b/examples/simple-mcp-client-streamable-http-core/README.md new file mode 100644 index 0000000..a0852fb --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client Core (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse-core +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http-core/src/handler.rs b/examples/simple-mcp-client-streamable-http-core/src/handler.rs new file mode 100644 index 0000000..ab86e9e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/handler.rs @@ -0,0 +1,72 @@ +use async_trait::async_trait; +use rust_mcp_sdk::schema::{ + self, + schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, + RpcError, ServerRequest, +}; +use rust_mcp_sdk::{mcp_client::ClientHandlerCore, McpClient}; +pub struct MyClientHandler; + +// To check out a list of all the methods in the trait that you can override, take a look at +// https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs + +#[async_trait] +impl ClientHandlerCore for MyClientHandler { + async fn handle_request( + &self, + request: RequestFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result { + match request { + RequestFromServer::ServerRequest(server_request) => match server_request { + ServerRequest::PingRequest(_) => { + return Ok(schema::Result::default().into()); + } + ServerRequest::CreateMessageRequest(_create_message_request) => { + Err(RpcError::internal_error().with_message( + "CreateMessageRequest handler is not implemented".to_string(), + )) + } + ServerRequest::ListRootsRequest(_list_roots_request) => { + Err(RpcError::internal_error() + .with_message("ListRootsRequest handler is not implemented".to_string())) + } + ServerRequest::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() + .with_message("ElicitRequest handler is not implemented".to_string())), + }, + RequestFromServer::CustomRequest(_value) => Err(RpcError::internal_error() + .with_message("CustomRequest handler is not implemented".to_string())), + } + } + + async fn handle_notification( + &self, + notification: NotificationFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) + } + + async fn handle_error( + &self, + _error: &RpcError, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/main.rs b/examples/simple-mcp-client-streamable-http-core/src/main.rs new file mode 100644 index 0000000..e1a5849 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/main.rs @@ -0,0 +1,95 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use inquiry_utils::InquiryUtils; +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime_core; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +// Assuming @modelcontextprotocol/server-everything is launched with streamableHttp argument and listening on port 3001 +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-core-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (Core,SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client + let client = + client_runtime_core::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + utils.client.set_logging_level(LoggingLevel::Debug).await?; + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +} diff --git a/examples/simple-mcp-client-streamable-http/Cargo.toml b/examples/simple-mcp-client-streamable-http/Cargo.toml new file mode 100644 index 0000000..bf2827a --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http" +version = "0.1.1" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "streamable-http", + "macros", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http/README.md b/examples/simple-mcp-client-streamable-http/README.md new file mode 100644 index 0000000..5b4488e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http/src/handler.rs b/examples/simple-mcp-client-streamable-http/src/handler.rs new file mode 100644 index 0000000..19360f6 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/handler.rs @@ -0,0 +1,10 @@ +use async_trait::async_trait; +use rust_mcp_sdk::mcp_client::ClientHandler; + +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at + // https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} diff --git a/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http/src/main.rs b/examples/simple-mcp-client-streamable-http/src/main.rs new file mode 100644 index 0000000..95d4d8d --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/main.rs @@ -0,0 +1,99 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +use crate::inquiry_utils::InquiryUtils; + +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +}