diff --git a/Directory.Packages.props b/Directory.Packages.props index 6f5e8c2e2..6e61a29f9 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -68,6 +68,7 @@ + diff --git a/README.md b/README.md index ed89bfa4a..6ad5aa89a 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ McpServerOptions options = new() Tools = new ToolsCapability() { ListToolsHandler = (request, cancellationToken) => - Task.FromResult(new ListToolsResult() + ValueTask.FromResult(new ListToolsResult() { Tools = [ @@ -202,7 +202,7 @@ McpServerOptions options = new() throw new McpException("Missing required argument 'message'"); } - return Task.FromResult(new CallToolResponse() + return ValueTask.FromResult(new CallToolResponse() { Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] }); diff --git a/samples/TestServerWithHosting/TestServerWithHosting.csproj b/samples/TestServerWithHosting/TestServerWithHosting.csproj index 9ddb6190d..2466a2811 100644 --- a/samples/TestServerWithHosting/TestServerWithHosting.csproj +++ b/samples/TestServerWithHosting/TestServerWithHosting.csproj @@ -5,21 +5,16 @@ net9.0;net8.0;net472 enable enable - + true + - diff --git a/src/Directory.Build.props b/src/Directory.Build.props index b15b0cd38..d75e650cd 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -6,7 +6,7 @@ https://github.com/modelcontextprotocol/csharp-sdk git 0.2.0 - preview.1 + preview.2 ModelContextProtocolOfficial © Anthropic and Contributors. ModelContextProtocol;mcp;ai;llm diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index ec545c990..6077efa10 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -4,6 +4,7 @@ using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; @@ -28,9 +29,6 @@ internal sealed class StreamableHttpHandler( { private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); - private static readonly MediaTypeHeaderValue s_applicationJsonMediaType = new("application/json"); - private static readonly MediaTypeHeaderValue s_textEventStreamMediaType = new("text/event-stream"); - public ConcurrentDictionary> Sessions { get; } = new(StringComparer.Ordinal); public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; @@ -43,8 +41,8 @@ public async Task HandlePostRequestAsync(HttpContext context) // ASP.NET Core Minimal APIs mostly try to stay out of the business of response content negotiation, // so we have to do this manually. The spec doesn't mandate that servers MUST reject these requests, // but it's probably good to at least start out trying to be strict. - var acceptHeaders = context.Request.GetTypedHeaders().Accept; - if (!acceptHeaders.Contains(s_applicationJsonMediaType) || !acceptHeaders.Contains(s_textEventStreamMediaType)) + var typedHeaders = context.Request.GetTypedHeaders(); + if (!typedHeaders.Accept.Any(MatchesApplicationJsonMediaType) || !typedHeaders.Accept.Any(MatchesTextEventStreamMediaType)) { await WriteJsonRpcErrorAsync(context, "Not Acceptable: Client must accept both application/json and text/event-stream", @@ -85,8 +83,7 @@ await WriteJsonRpcErrorAsync(context, public async Task HandleGetRequestAsync(HttpContext context) { - var acceptHeaders = context.Request.GetTypedHeaders().Accept; - if (!acceptHeaders.Contains(s_textEventStreamMediaType)) + if (!context.Request.GetTypedHeaders().Accept.Any(MatchesTextEventStreamMediaType)) { await WriteJsonRpcErrorAsync(context, "Not Acceptable: Client must accept text/event-stream", @@ -331,6 +328,12 @@ internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptHeaderValue) + => acceptHeaderValue.MatchesMediaType("application/json"); + + private static bool MatchesTextEventStreamMediaType(MediaTypeHeaderValue acceptHeaderValue) + => acceptHeaderValue.MatchesMediaType("text/event-stream"); + private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe { public PipeReader Input => context.Request.BodyReader; diff --git a/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs new file mode 100644 index 000000000..50601f666 --- /dev/null +++ b/src/ModelContextProtocol/Client/AutoDetectingClientSessionTransport.cs @@ -0,0 +1,143 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Net; +using System.Threading.Channels; + +namespace ModelContextProtocol.Client; + +/// +/// A transport that automatically detects whether to use Streamable HTTP or SSE transport +/// by trying Streamable HTTP first and falling back to SSE if that fails. +/// +internal sealed partial class AutoDetectingClientSessionTransport : ITransport +{ + private readonly SseClientTransportOptions _options; + private readonly HttpClient _httpClient; + private readonly ILoggerFactory? _loggerFactory; + private readonly ILogger _logger; + private readonly string _name; + private readonly Channel _messageChannel; + + public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) + { + Throw.IfNull(transportOptions); + Throw.IfNull(httpClient); + + _options = transportOptions; + _httpClient = httpClient; + _loggerFactory = loggerFactory; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _name = endpointName; + + // Same as TransportBase.cs. + _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + { + SingleReader = true, + SingleWriter = false, + }); + } + + /// + /// Returns the active transport (either StreamableHttp or SSE) + /// + internal ITransport? ActiveTransport { get; private set; } + + public ChannelReader MessageReader => _messageChannel.Reader; + + /// + public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (ActiveTransport is null) + { + return InitializeAsync(message, cancellationToken); + } + + return ActiveTransport.SendMessageAsync(message, cancellationToken); + } + + private async Task InitializeAsync(JsonRpcMessage message, CancellationToken cancellationToken) + { + // Try StreamableHttp first + var streamableHttpTransport = new StreamableHttpClientSessionTransport(_name, _options, _httpClient, _messageChannel, _loggerFactory); + + try + { + LogAttemptingStreamableHttp(_name); + using var response = await streamableHttpTransport.SendHttpRequestAsync(message, cancellationToken).ConfigureAwait(false); + + if (response.IsSuccessStatusCode) + { + LogUsingStreamableHttp(_name); + ActiveTransport = streamableHttpTransport; + } + else + { + // If the status code is not success, fall back to SSE + LogStreamableHttpFailed(_name, response.StatusCode); + + await streamableHttpTransport.DisposeAsync().ConfigureAwait(false); + await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false); + } + } + catch + { + // If nothing threw inside the try block, we've either set streamableHttpTransport as the + // ActiveTransport, or else we will have disposed it in the !IsSuccessStatusCode else block. + await streamableHttpTransport.DisposeAsync().ConfigureAwait(false); + throw; + } + } + + private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) + { + var sseTransport = new SseClientSessionTransport(_name, _options, _httpClient, _messageChannel, _loggerFactory); + + try + { + LogAttemptingSSE(_name); + await sseTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + await sseTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + + LogUsingSSE(_name); + ActiveTransport = sseTransport; + } + catch + { + await sseTransport.DisposeAsync().ConfigureAwait(false); + throw; + } + } + + public async ValueTask DisposeAsync() + { + try + { + if (ActiveTransport is not null) + { + await ActiveTransport.DisposeAsync().ConfigureAwait(false); + } + } + finally + { + // In the majority of cases, either the Streamable HTTP transport or SSE transport has completed the channel by now. + // However, this may not be the case if HttpClient throws during the initial request due to misconfiguration. + _messageChannel.Writer.TryComplete(); + } + } + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} attempting to connect using Streamable HTTP transport.")] + private partial void LogAttemptingStreamableHttp(string endpointName); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} streamable HTTP transport failed with status code {StatusCode}, falling back to SSE transport.")] + private partial void LogStreamableHttpFailed(string endpointName, HttpStatusCode statusCode); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} using Streamable HTTP transport.")] + private partial void LogUsingStreamableHttp(string endpointName); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} attempting to connect using SSE transport.")] + private partial void LogAttemptingSSE(string endpointName); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} using SSE transport.")] + private partial void LogUsingSSE(string endpointName); +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/HttpTransportMode.cs b/src/ModelContextProtocol/Client/HttpTransportMode.cs new file mode 100644 index 000000000..f2d46c302 --- /dev/null +++ b/src/ModelContextProtocol/Client/HttpTransportMode.cs @@ -0,0 +1,23 @@ +namespace ModelContextProtocol.Client; + +/// +/// Specifies the transport mode for HTTP client connections. +/// +public enum HttpTransportMode +{ + /// + /// Automatically detect the appropriate transport by trying Streamable HTTP first, then falling back to SSE if that fails. + /// This is the recommended mode for maximum compatibility. + /// + AutoDetect, + + /// + /// Use only the Streamable HTTP transport. + /// + StreamableHttp, + + /// + /// Use only the HTTP with SSE transport. + /// + Sse +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol/Client/SseClientSessionTransport.cs index 78997b5e8..fd2466eaf 100644 --- a/src/ModelContextProtocol/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Client/SseClientSessionTransport.cs @@ -6,6 +6,7 @@ using System.Net.ServerSentEvents; using System.Text; using System.Text.Json; +using System.Threading.Channels; namespace ModelContextProtocol.Client; @@ -24,15 +25,16 @@ internal sealed partial class SseClientSessionTransport : TransportBase private readonly TaskCompletionSource _connectionEstablished; /// - /// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server. + /// SSE transport for a single session. Unlike stdio it does not launch a process, but connects to an existing server. /// The HTTP server can be local or remote, and must support the SSE protocol. /// - /// Configuration options for the transport. - /// The HTTP client instance used for requests. - /// Logger factory for creating loggers. - /// The endpoint name used for logging purposes. - public SseClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) - : base(endpointName, loggerFactory) + public SseClientSessionTransport( + string endpointName, + SseClientTransportOptions transportOptions, + HttpClient httpClient, + Channel? messageChannel, + ILoggerFactory? loggerFactory) + : base(endpointName, messageChannel, loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); @@ -92,26 +94,18 @@ public override async Task SendMessageAsync( StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders); var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); - response.EnsureSuccessStatusCode(); - - var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - - if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) - { - LogAcceptedPost(Name, messageId); - } - else + if (!response.IsSuccessStatusCode) { if (_logger.IsEnabled(LogLevel.Trace)) { - LogRejectedPostSensitive(Name, messageId, responseContent); + LogRejectedPostSensitive(Name, messageId, await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false)); } else { LogRejectedPost(Name, messageId); } - throw new InvalidOperationException("Failed to send message"); + response.EnsureSuccessStatusCode(); } } diff --git a/src/ModelContextProtocol/Client/SseClientTransport.cs b/src/ModelContextProtocol/Client/SseClientTransport.cs index df1cdac6c..57789c1cc 100644 --- a/src/ModelContextProtocol/Client/SseClientTransport.cs +++ b/src/ModelContextProtocol/Client/SseClientTransport.cs @@ -4,11 +4,11 @@ namespace ModelContextProtocol.Client; /// -/// Provides an over HTTP using the Server-Sent Events (SSE) protocol. +/// Provides an over HTTP using the Server-Sent Events (SSE) or Streamable HTTP protocol. /// /// -/// This transport connects to an MCP server over HTTP using SSE, -/// allowing for real-time server-to-client communication with a standard HTTP request. +/// This transport connects to an MCP server over HTTP using SSE or Streamable HTTP, +/// allowing for real-time server-to-client communication with a standard HTTP requests. /// Unlike the , this transport connects to an existing server /// rather than launching a new process. /// @@ -36,7 +36,7 @@ public SseClientTransport(SseClientTransportOptions transportOptions, ILoggerFac /// The HTTP client instance used for requests. /// Logger factory for creating loggers used for diagnostic output during transport operations. /// - /// to dispose of when the transport is disposed; + /// to dispose of when the transport is disposed; /// if the caller is retaining ownership of the 's lifetime. /// public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool ownsHttpClient = false) @@ -57,12 +57,22 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - if (_options.UseStreamableHttp) + switch (_options.TransportMode) { - return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name); + case HttpTransportMode.AutoDetect: + return new AutoDetectingClientSessionTransport(_options, _httpClient, _loggerFactory, Name); + case HttpTransportMode.StreamableHttp: + return new StreamableHttpClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory); + case HttpTransportMode.Sse: + return await ConnectSseTransportAsync(cancellationToken).ConfigureAwait(false); + default: + throw new InvalidOperationException($"Unsupported transport mode: {_options.TransportMode}"); } + } - var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name); + private async Task ConnectSseTransportAsync(CancellationToken cancellationToken) + { + var sessionTransport = new SseClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory); try { diff --git a/src/ModelContextProtocol/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol/Client/SseClientTransportOptions.cs index f67f6f07d..8843fca80 100644 --- a/src/ModelContextProtocol/Client/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Client/SseClientTransportOptions.cs @@ -31,11 +31,19 @@ public required Uri Endpoint } /// - /// Gets or sets a value indicating whether to use "Streamable HTTP" for the transport rather than "HTTP with SSE". Defaults to false. + /// Gets or sets the transport mode to use for the connection. Defaults to . + /// + /// + /// + /// When set to (the default), the client will first attempt to use + /// Streamable HTTP transport and automatically fall back to SSE transport if the server doesn't support it. + /// + /// /// Streamable HTTP transport specification. /// HTTP with SSE transport specification. - /// - public bool UseStreamableHttp { get; init; } + /// + /// + public HttpTransportMode TransportMode { get; init; } = HttpTransportMode.AutoDetect; /// /// Gets a transport identifier used for logging purposes. diff --git a/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs index 3330f4def..e35e2b18e 100644 --- a/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Client/StreamClientSessionTransport.cs @@ -17,7 +17,7 @@ internal class StreamClientSessionTransport : TransportBase /// Initializes a new instance of the class. /// /// - /// The text writer connected to the server's input stream. + /// The text writer connected to the server's input stream. /// Messages written to this writer will be sent to the server. /// /// @@ -41,17 +41,17 @@ public StreamClientSessionTransport( _serverOutput = serverOutput; _serverInput = serverInput; + SetConnected(); + // Start reading messages in the background. We use the rarer pattern of new Task + Start // in order to ensure that the body of the task will always see _readTask initialized. // It is then able to reliably null it out on completion. var readTask = new Task( - thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token), + thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token), this, TaskCreationOptions.DenyChildAttach); _readTask = readTask.Unwrap(); readTask.Start(); - - SetConnected(); } /// @@ -80,7 +80,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation } /// - public override ValueTask DisposeAsync() => + public override ValueTask DisposeAsync() => CleanupAsync(cancellationToken: CancellationToken.None); private async Task ReadMessagesAsync(CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs index f90349001..78f99e20d 100644 --- a/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol/Client/StreamableHttpClientSessionTransport.cs @@ -4,6 +4,8 @@ using System.Net.ServerSentEvents; using System.Text.Json; using ModelContextProtocol.Protocol; +using System.Threading.Channels; + #if NET using System.Net.Http.Json; #else @@ -28,8 +30,13 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private string? _mcpSessionId; private Task? _getReceiveTask; - public StreamableHttpClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) - : base(endpointName, loggerFactory) + public StreamableHttpClientSessionTransport( + string endpointName, + SseClientTransportOptions transportOptions, + HttpClient httpClient, + Channel? messageChannel, + ILoggerFactory? loggerFactory) + : base(endpointName, messageChannel, loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); @@ -46,9 +53,15 @@ public StreamableHttpClientSessionTransport(SseClientTransportOptions transportO } /// - public override async Task SendMessageAsync( - JsonRpcMessage message, - CancellationToken cancellationToken = default) + public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + // Immediately dispose the response. SendHttpRequestAsync only returns the response so the auto transport can look at it. + using var response = await SendHttpRequestAsync(message, cancellationToken).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); + } + + // This is used by the auto transport so it can fall back and try SSE given a non-200 response without catching an exception. + internal async Task SendHttpRequestAsync(JsonRpcMessage message, CancellationToken cancellationToken) { using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); cancellationToken = sendCts.Token; @@ -59,7 +72,7 @@ public override async Task SendMessageAsync( using var content = new StringContent( JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), Encoding.UTF8, - "application/json" + "application/json; charset=utf-8" ); #endif @@ -73,9 +86,14 @@ public override async Task SendMessageAsync( }; CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId); - using var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); - response.EnsureSuccessStatusCode(); + var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + + // We'll let the caller decide whether to throw or fall back given an unsuccessful response. + if (!response.IsSuccessStatusCode) + { + return response; + } var rpcRequest = message as JsonRpcRequest; JsonRpcMessage? rpcResponseCandidate = null; @@ -93,7 +111,7 @@ public override async Task SendMessageAsync( if (rpcRequest is null) { - return; + return response; } if (rpcResponseCandidate is not JsonRpcMessageWithId messageWithId || messageWithId.Id != rpcRequest.Id) @@ -111,6 +129,8 @@ public override async Task SendMessageAsync( _getReceiveTask = ReceiveUnsolicitedMessagesAsync(); } + + return response; } public override async ValueTask DisposeAsync() @@ -136,7 +156,12 @@ public override async ValueTask DisposeAsync() } finally { - SetDisconnected(); + // If we're auto-detecting the transport and failed to connect, leave the message Channel open for the SSE transport. + // This class isn't directly exposed to public callers, so we don't have to worry about changing the _state in this case. + if (_options.TransportMode is not HttpTransportMode.AutoDetect || _getReceiveTask is not null) + { + SetDisconnected(); + } } } diff --git a/src/ModelContextProtocol/McpJsonUtilities.cs b/src/ModelContextProtocol/McpJsonUtilities.cs index ca9748437..162bc2343 100644 --- a/src/ModelContextProtocol/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/McpJsonUtilities.cs @@ -21,7 +21,7 @@ public static partial class McpJsonUtilities /// /// It additionally turns on the following settings: /// - /// Enables string-based enum serialization as implemented by . + /// Enables defaults. /// Enables as the default ignore condition for properties. /// Enables as the default number handling for number types. /// diff --git a/src/ModelContextProtocol/ProcessHelper.cs b/src/ModelContextProtocol/ProcessHelper.cs index 7bfe99ab4..c8bae0c48 100644 --- a/src/ModelContextProtocol/ProcessHelper.cs +++ b/src/ModelContextProtocol/ProcessHelper.cs @@ -107,6 +107,7 @@ private static int RunProcessAndWaitForExit(string fileName, string arguments, T RedirectStandardOutput = true, RedirectStandardError = true, UseShellExecute = false, + CreateNoWindow = true, }; stdout = null; diff --git a/src/ModelContextProtocol/Protocol/TransportBase.cs b/src/ModelContextProtocol/Protocol/TransportBase.cs index 31b3b146f..9be9c6fa5 100644 --- a/src/ModelContextProtocol/Protocol/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/TransportBase.cs @@ -36,12 +36,20 @@ public abstract partial class TransportBase : ITransport /// Initializes a new instance of the class. /// protected TransportBase(string name, ILoggerFactory? loggerFactory) + : this(name, null, loggerFactory) + { + } + + /// + /// Initializes a new instance of the class with a specified channel to back . + /// + internal TransportBase(string name, Channel? messageChannel, ILoggerFactory? loggerFactory) { Name = name; _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; - // Unbounded channel to prevent blocking on writes - _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions + // Unbounded channel to prevent blocking on writes. Ensure AutoDetectingClientSessionTransport matches this. + _messageChannel = messageChannel ?? Channel.CreateUnbounded(new UnboundedChannelOptions { SingleReader = true, SingleWriter = false, @@ -112,7 +120,7 @@ protected void SetConnected() case StateConnected: return; - + case StateDisconnected: throw new IOException("Transport is already disconnected and can't be reconnected."); diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 366eb23cd..7f91186b1 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -1,5 +1,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; using System.Diagnostics.CodeAnalysis; @@ -9,8 +11,10 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . -internal sealed class AIFunctionMcpServerTool : McpServerTool +internal sealed partial class AIFunctionMcpServerTool : McpServerTool { + private readonly ILogger _logger; + /// /// Creates an instance for a method, specified via a instance. /// @@ -194,7 +198,7 @@ options.OpenWorld is not null || } } - return new AIFunctionMcpServerTool(function, tool); + return new AIFunctionMcpServerTool(function, tool, options?.Services); } private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options) @@ -239,10 +243,11 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerTool(AIFunction function, Tool tool) + private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider) { AIFunction = function; ProtocolTool = tool; + _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; } /// @@ -277,6 +282,8 @@ public override async ValueTask InvokeAsync( } catch (Exception e) when (e is not OperationCanceledException) { + ToolCallError(request.Params?.Name ?? string.Empty, e); + string errorMessage = e is McpException ? $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : $"An error occurred invoking '{request.Params?.Name}'."; @@ -359,4 +366,7 @@ private static CallToolResponse ConvertAIContentEnumerableToCallToolResponse(IEn IsError = allErrorContent && hasAny }; } + + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] + private partial void ToolCallError(string toolName, Exception exception); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index ee717530d..30187faad 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -56,7 +56,7 @@ public async Task Connect_TestServer_ShouldProvideServerFields() [Fact] public async Task ListTools_Sse_TestServer() - { + { // arrange // act diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index c987bca90..be8763ae4 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -1,5 +1,6 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; namespace ModelContextProtocol.AspNetCore.Tests; @@ -34,4 +35,112 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); } + + [Fact] + public async Task StreamableHttpMode_Works_WithRootEndpoint() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "StreamableHttpTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync("/", new() + { + Endpoint = new Uri("http://localhost/"), + TransportMode = HttpTransportMode.AutoDetect + }); + + Assert.Equal("StreamableHttpTestServer", mcpClient.ServerInfo.Name); + } + + [Fact] + public async Task AutoDetectMode_Works_WithRootEndpoint() + { + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "AutoDetectTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync("/", new() + { + Endpoint = new Uri("http://localhost/"), + TransportMode = HttpTransportMode.AutoDetect + }); + + Assert.Equal("AutoDetectTestServer", mcpClient.ServerInfo.Name); + } + + [Fact] + public async Task AutoDetectMode_Works_WithSseEndpoint() + { + Assert.SkipWhen(Stateless, "SSE endpoint is disabled in stateless mode."); + + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "AutoDetectSseTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync("/sse", new() + { + Endpoint = new Uri("http://localhost/sse"), + TransportMode = HttpTransportMode.AutoDetect + }); + + Assert.Equal("AutoDetectSseTestServer", mcpClient.ServerInfo.Name); + } + + [Fact] + public async Task SseMode_Works_WithSseEndpoint() + { + Assert.SkipWhen(Stateless, "SSE endpoint is disabled in stateless mode."); + + Builder.Services.AddMcpServer(options => + { + options.ServerInfo = new() + { + Name = "SseTestServer", + Version = "1.0.0", + }; + }).WithHttpTransport(ConfigureStateless); + await using var app = Builder.Build(); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync(options: new() + { + Endpoint = new Uri("http://localhost/sse"), + TransportMode = HttpTransportMode.Sse + }); + + Assert.Equal("SseTestServer", mcpClient.ServerInfo.Name); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index cf49fee16..6d1532207 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -20,16 +20,17 @@ protected void ConfigureStateless(HttpServerTransportOptions options) options.Stateless = Stateless; } - protected async Task ConnectAsync(string? path = null) + protected async Task ConnectAsync(string? path = null, SseClientTransportOptions? options = null) { + // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; - var sseClientTransportOptions = new SseClientTransportOptions() + await using var transport = new SseClientTransport(options ?? new SseClientTransportOptions() { Endpoint = new Uri($"http://localhost{path}/"), - UseStreamableHttp = UseStreamableHttp, - }; - await using var transport = new SseClientTransport(sseClientTransportOptions, HttpClient, LoggerFactory); + TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, + }, HttpClient, LoggerFactory); + return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 24acd0b92..fc186c400 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -58,6 +58,22 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU Assert.True(true); } + [Fact] + public async Task ConnectAndReceiveMessage_ServerReturningJsonInPostRequest() + { + await using var app = Builder.Build(); + MapAbsoluteEndpointUriMcp(app, respondInJson: true); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectMcpClientAsync(); + + // Send a test message through POST endpoint + await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken); + + Assert.True(true); + } + [Fact] public async Task ConnectAndReceiveNotification_InMemoryServer() { @@ -220,7 +236,7 @@ public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException() Assert.Equal("Failed to add header '' with value '' from AdditionalHeaders.", ex.Message); } - private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints) + private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints, bool respondInJson = false) { var loggerFactory = endpoints.ServiceProvider.GetRequiredService(); var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService>(); @@ -267,7 +283,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints) await Results.BadRequest("Session not started.").ExecuteAsync(context); return; } - var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted); + var message = await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions, context.RequestAborted); if (message is null) { await Results.BadRequest("No message in request body.").ExecuteAsync(context); @@ -276,7 +292,15 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints) await session.OnMessageReceivedAsync(message, context.RequestAborted); context.Response.StatusCode = StatusCodes.Status202Accepted; - await context.Response.WriteAsync("Accepted"); + + if (respondInJson) + { + await context.Response.WriteAsJsonAsync(message, McpJsonUtilities.DefaultOptions, cancellationToken: context.RequestAborted); + } + else + { + await context.Response.WriteAsync("Accepted"); + } }); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index b1b618057..a9e2e5f54 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -9,6 +9,6 @@ public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fix { Endpoint = new Uri("http://localhost/stateless"), Name = "In-memory Streamable HTTP Client", - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }; } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index 2f364be01..acfc744b9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -18,7 +18,7 @@ public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMem { Endpoint = new Uri("http://localhost/"), Name = "In-memory Streamable HTTP Client", - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }; private async Task StartAsync() diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 94540f8c2..d7f8433b3 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -98,7 +98,7 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer() await using var transport = new SseClientTransport(new() { Endpoint = new("http://localhost/mcp"), - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -118,7 +118,7 @@ public async Task CanCallToolConcurrently() await using var transport = new SseClientTransport(new() { Endpoint = new("http://localhost/mcp"), - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 17b1234e3..3efc10419 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -105,7 +105,7 @@ public async Task PostRequest_IsUnsupportedMediaType_WithoutJsonContentType() [InlineData("text/event-stream")] [InlineData("application/json")] [InlineData("application/json-text/event-stream")] - public async Task PostRequest_IsNotAcceptable_WithSingleAcceptHeader(string singleAcceptValue) + public async Task PostRequest_IsNotAcceptable_WithSingleSpecificAcceptHeader(string singleAcceptValue) { await StartAsync(); @@ -116,6 +116,20 @@ public async Task PostRequest_IsNotAcceptable_WithSingleAcceptHeader(string sing Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); } + [Theory] + [InlineData("*/*")] + [InlineData("text/event-stream, application/json;q=0.9")] + public async Task PostRequest_IsAcceptable_WithWildcardOrAddedQualityInAcceptHeader(string acceptHeaderValue) + { + await StartAsync(); + + HttpClient.DefaultRequestHeaders.Accept.Clear(); + HttpClient.DefaultRequestHeaders.TryAddWithoutValidation(HeaderNames.Accept, acceptHeaderValue); + + using var response = await HttpClient.PostAsync("", JsonContent(InitializeRequest), TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + [Fact] public async Task GetRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader() { @@ -128,6 +142,22 @@ public async Task GetRequest_IsNotAcceptable_WithoutTextEventStreamAcceptHeader( Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); } + [Theory] + [InlineData("*/*")] + [InlineData("application/json, text/event-stream;q=0.9")] + public async Task GetRequest_IsAcceptable_WithWildcardOrAddedQualityInAcceptHeader(string acceptHeaderValue) + { + await StartAsync(); + + HttpClient.DefaultRequestHeaders.Accept.Clear(); + HttpClient.DefaultRequestHeaders.TryAddWithoutValidation(HeaderNames.Accept, acceptHeaderValue); + + await CallInitializeAndValidateAsync(); + + using var response = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + [Fact] public async Task PostRequest_IsNotFound_WithUnrecognizedSessionId() { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 64505b3d9..7c4366f16 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -15,7 +15,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur { Endpoint = new Uri("http://localhost/"), Name = "In-memory Streamable HTTP Client", - UseStreamableHttp = true, + TransportMode = HttpTransportMode.StreamableHttp, }; [Fact] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index cb98d9bce..db6f1cde4 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,7 +1,9 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; using System.Text.Json; @@ -381,6 +383,45 @@ public async Task SupportsSchemaCreateOptions() ); } + [Fact] + public async Task ToolCallError_LogsErrorMessage() + { + // Arrange + var mockLoggerProvider = new MockLoggerProvider(); + var loggerFactory = new LoggerFactory(new[] { mockLoggerProvider }); + var services = new ServiceCollection(); + services.AddSingleton(loggerFactory); + var serviceProvider = services.BuildServiceProvider(); + + var toolName = "tool-that-throws"; + var exceptionMessage = "Test exception message"; + + McpServerTool tool = McpServerTool.Create(() => + { + throw new InvalidOperationException(exceptionMessage); + }, new() { Name = toolName, Services = serviceProvider }); + + var mockServer = new Mock(); + var request = new RequestContext(mockServer.Object) + { + Params = new CallToolRequestParams() { Name = toolName }, + Services = serviceProvider + }; + + // Act + var result = await tool.InvokeAsync(request, TestContext.Current.CancellationToken); + + // Assert + Assert.True(result.IsError); + Assert.Single(result.Content); + Assert.Equal($"An error occurred invoking '{toolName}'.", result.Content[0].Text); + + var errorLog = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); + Assert.Equal($"\"{toolName}\" threw an unhandled exception.", errorLog.Message); + Assert.IsType(errorLog.Exception); + Assert.Equal(exceptionMessage, errorLog.Exception.Message); + } + private sealed class MyService; private class DisposableToolType : IDisposable diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs new file mode 100644 index 000000000..8f6fbff2c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs @@ -0,0 +1,109 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Tests.Utils; +using System.Net; + +namespace ModelContextProtocol.Tests.Transport; + +public class SseClientTransportAutoDetectTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + [Fact] + public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() + { + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost/"), + TransportMode = HttpTransportMode.AutoDetect, + Name = "AutoDetect test client" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + + // Simulate successful Streamable HTTP response for initialize + mockHttpHandler.RequestHandler = (request) => + { + if (request.Method == HttpMethod.Post) + { + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("{\"jsonrpc\":\"2.0\",\"id\":\"init-id\",\"result\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{\"tools\":{}}}}"), + Headers = + { + { "Content-Type", "application/json" }, + { "mcp-session-id", "test-session" } + } + }); + } + + // Shouldn't reach here for successful Streamable HTTP + throw new InvalidOperationException("Unexpected request"); + }; + + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + // The auto-detecting transport should be returned + Assert.NotNull(session); + } + + [Fact] + public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails() + { + var options = new SseClientTransportOptions + { + Endpoint = new Uri("http://localhost/"), + TransportMode = HttpTransportMode.AutoDetect, + Name = "AutoDetect test client" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + + var requestCount = 0; + + mockHttpHandler.RequestHandler = (request) => + { + requestCount++; + + if (request.Method == HttpMethod.Post && requestCount == 1) + { + // First POST (Streamable HTTP) fails + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.NotFound, + Content = new StringContent("Streamable HTTP not supported") + }); + } + + if (request.Method == HttpMethod.Get) + { + // SSE connection request + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("event: endpoint\r\ndata: /sse-endpoint\r\n\r\n"), + Headers = { { "Content-Type", "text/event-stream" } } + }); + } + + if (request.Method == HttpMethod.Post && requestCount > 1) + { + // Subsequent POST to SSE endpoint succeeds + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent("accepted") + }); + } + + throw new InvalidOperationException($"Unexpected request: {request.Method}, count: {requestCount}"); + }; + + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + // The auto-detecting transport should be returned + Assert.NotNull(session); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 857e496aa..ae449ac9f 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -17,6 +17,7 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper) Endpoint = new Uri("http://localhost:8080/"), ConnectionTimeout = TimeSpan.FromSeconds(2), Name = "Test Server", + TransportMode = HttpTransportMode.Sse, AdditionalHeaders = new Dictionary { ["test"] = "header" diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index 7a7f39c20..b8d8d714b 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -1,9 +1,10 @@ using ModelContextProtocol.Client; +using ModelContextProtocol.Tests.Utils; using System.Runtime.InteropServices; namespace ModelContextProtocol.Tests.Transport; -public class StdioClientTransportTests +public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() @@ -11,10 +12,10 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() string id = Guid.NewGuid().ToString("N"); StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }) : - new(new() { Command = "ls", Arguments = [id] }); + new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }, LoggerFactory) : + new(new() { Command = "ls", Arguments = [id] }, LoggerFactory); - IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken)); + IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.Contains(id, e.ToString()); } }