diff --git a/Directory.Packages.props b/Directory.Packages.props index 78133ed99..2e377c4ff 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -19,6 +19,7 @@ + @@ -26,6 +27,8 @@ + + @@ -62,7 +65,7 @@ - + diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index e4fd42fe8..5ed8ba0d6 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -12,6 +12,8 @@ + + @@ -33,6 +35,7 @@ + diff --git a/README.md b/README.md index 073243065..163d57f8a 100644 --- a/README.md +++ b/README.md @@ -166,24 +166,24 @@ public static class MyPrompts More control is also available, with fine-grained control over configuring the server and how it should handle client requests. For example: ```csharp -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol; +using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using System.Text.Json; McpServerOptions options = new() { - ServerInfo = new Implementation() { Name = "MyServer", Version = "1.0.0" }, - Capabilities = new ServerCapabilities() + ServerInfo = new Implementation { Name = "MyServer", Version = "1.0.0" }, + Capabilities = new ServerCapabilities { - Tools = new ToolsCapability() + Tools = new ToolsCapability { ListToolsHandler = (request, cancellationToken) => - ValueTask.FromResult(new ListToolsResult() + ValueTask.FromResult(new ListToolsResult { Tools = [ - new Tool() + new Tool { Name = "echo", Description = "Echoes the input back to the client.", @@ -212,9 +212,9 @@ McpServerOptions options = new() throw new McpException("Missing required argument 'message'"); } - return ValueTask.FromResult(new CallToolResponse() + return ValueTask.FromResult(new CallToolResult { - Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] + Content = [new TextContentBlock { Text = $"Echo: {message}", Type = "text" }] }); } diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs index f24b6a17c..c21b328f6 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreSseServer/Program.cs @@ -2,12 +2,14 @@ using OpenTelemetry.Metrics; using OpenTelemetry.Trace; using TestServerWithHosting.Tools; +using TestServerWithHosting.Resources; var builder = WebApplication.CreateBuilder(args); builder.Services.AddMcpServer() .WithHttpTransport() .WithTools() - .WithTools(); + .WithTools() + .WithResources(); builder.Services.AddOpenTelemetry() .WithTracing(b => b.AddSource("*") diff --git a/samples/AspNetCoreSseServer/Resources/SimpleResourceType.cs b/samples/AspNetCoreSseServer/Resources/SimpleResourceType.cs new file mode 100644 index 000000000..e73ce133c --- /dev/null +++ b/samples/AspNetCoreSseServer/Resources/SimpleResourceType.cs @@ -0,0 +1,12 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; + +namespace TestServerWithHosting.Resources; + +[McpServerResourceType] +public class SimpleResourceType +{ + [McpServerResource, Description("A direct text resource")] + public static string DirectTextResource() => "This is a direct resource"; +} diff --git a/samples/EverythingServer/Tools/AnnotatedMessageTool.cs b/samples/EverythingServer/Tools/AnnotatedMessageTool.cs index c14a2da2f..7f92d0ae1 100644 --- a/samples/EverythingServer/Tools/AnnotatedMessageTool.cs +++ b/samples/EverythingServer/Tools/AnnotatedMessageTool.cs @@ -39,7 +39,7 @@ public static IEnumerable AnnotatedMessage(MessageType messageType if (includeImage) { - contents.Add(new ImageContentBlock() + contents.Add(new ImageContentBlock { Data = TinyImageTool.MCP_TINY_IMAGE.Split(",").Last(), MimeType = "image/png", diff --git a/samples/EverythingServer/Tools/SampleLlmTool.cs b/samples/EverythingServer/Tools/SampleLlmTool.cs index 720cbead1..a58675c30 100644 --- a/samples/EverythingServer/Tools/SampleLlmTool.cs +++ b/samples/EverythingServer/Tools/SampleLlmTool.cs @@ -22,9 +22,9 @@ public static async Task SampleLLM( private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) { - return new CreateMessageRequestParams() + return new CreateMessageRequestParams { - Messages = [new SamplingMessage() + Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, diff --git a/samples/ProtectedMCPClient/Program.cs b/samples/ProtectedMCPClient/Program.cs new file mode 100644 index 000000000..516227b37 --- /dev/null +++ b/samples/ProtectedMCPClient/Program.cs @@ -0,0 +1,148 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using System.Diagnostics; +using System.Net; +using System.Text; +using System.Web; + +var serverUrl = "http://localhost:7071/"; + +Console.WriteLine("Protected MCP Client"); +Console.WriteLine($"Connecting to weather server at {serverUrl}..."); +Console.WriteLine(); + +// We can customize a shared HttpClient with a custom handler if desired +var sharedHandler = new SocketsHttpHandler +{ + PooledConnectionLifetime = TimeSpan.FromMinutes(2), + PooledConnectionIdleTimeout = TimeSpan.FromMinutes(1) +}; +var httpClient = new HttpClient(sharedHandler); + +var consoleLoggerFactory = LoggerFactory.Create(builder => +{ + builder.AddConsole(); +}); + +var transport = new SseClientTransport(new() +{ + Endpoint = new Uri(serverUrl), + Name = "Secure Weather Client", + OAuth = new() + { + ClientName = "ProtectedMcpClient", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + } +}, httpClient, consoleLoggerFactory); + +var client = await McpClientFactory.CreateAsync(transport, loggerFactory: consoleLoggerFactory); + +var tools = await client.ListToolsAsync(); +if (tools.Count == 0) +{ + Console.WriteLine("No tools available on the server."); + return; +} + +Console.WriteLine($"Found {tools.Count} tools on the server."); +Console.WriteLine(); + +if (tools.Any(t => t.Name == "get_alerts")) +{ + Console.WriteLine("Calling get_alerts tool..."); + + var result = await client.CallToolAsync( + "get_alerts", + new Dictionary { { "state", "WA" } } + ); + + Console.WriteLine("Result: " + ((TextContentBlock)result.Content[0]).Text); + Console.WriteLine(); +} + +/// Handles the OAuth authorization URL by starting a local HTTP server and opening a browser. +/// This implementation demonstrates how SDK consumers can provide their own authorization flow. +/// +/// The authorization URL to open in the browser. +/// The redirect URI where the authorization code will be sent. +/// The cancellation token. +/// The authorization code extracted from the callback, or null if the operation failed. +static async Task HandleAuthorizationUrlAsync(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken) +{ + Console.WriteLine("Starting OAuth authorization flow..."); + Console.WriteLine($"Opening browser to: {authorizationUrl}"); + + var listenerPrefix = redirectUri.GetLeftPart(UriPartial.Authority); + if (!listenerPrefix.EndsWith("/")) listenerPrefix += "/"; + + using var listener = new HttpListener(); + listener.Prefixes.Add(listenerPrefix); + + try + { + listener.Start(); + Console.WriteLine($"Listening for OAuth callback on: {listenerPrefix}"); + + OpenBrowser(authorizationUrl); + + var context = await listener.GetContextAsync(); + var query = HttpUtility.ParseQueryString(context.Request.Url?.Query ?? string.Empty); + var code = query["code"]; + var error = query["error"]; + + string responseHtml = "

Authentication complete

You can close this window now.

"; + byte[] buffer = Encoding.UTF8.GetBytes(responseHtml); + context.Response.ContentLength64 = buffer.Length; + context.Response.ContentType = "text/html"; + context.Response.OutputStream.Write(buffer, 0, buffer.Length); + context.Response.Close(); + + if (!string.IsNullOrEmpty(error)) + { + Console.WriteLine($"Auth error: {error}"); + return null; + } + + if (string.IsNullOrEmpty(code)) + { + Console.WriteLine("No authorization code received"); + return null; + } + + Console.WriteLine("Authorization code received successfully."); + return code; + } + catch (Exception ex) + { + Console.WriteLine($"Error getting auth code: {ex.Message}"); + return null; + } + finally + { + if (listener.IsListening) listener.Stop(); + } +} + +/// +/// Opens the specified URL in the default browser. +/// +/// The URL to open. +static void OpenBrowser(Uri url) +{ + try + { + var psi = new ProcessStartInfo + { + FileName = url.ToString(), + UseShellExecute = true + }; + Process.Start(psi); + } + catch (Exception ex) + { + Console.WriteLine($"Error opening browser. {ex.Message}"); + Console.WriteLine($"Please manually open this URL: {url}"); + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPClient/ProtectedMCPClient.csproj b/samples/ProtectedMCPClient/ProtectedMCPClient.csproj new file mode 100644 index 000000000..d1d476372 --- /dev/null +++ b/samples/ProtectedMCPClient/ProtectedMCPClient.csproj @@ -0,0 +1,18 @@ + + + + Exe + net9.0 + enable + enable + + + + + + + + + + + \ No newline at end of file diff --git a/samples/ProtectedMCPClient/README.md b/samples/ProtectedMCPClient/README.md new file mode 100644 index 000000000..977331a04 --- /dev/null +++ b/samples/ProtectedMCPClient/README.md @@ -0,0 +1,93 @@ +# Protected MCP Client Sample + +This sample demonstrates how to create an MCP client that connects to a protected MCP server using OAuth 2.0 authentication. The client implements a custom OAuth authorization flow with browser-based authentication. + +## Overview + +The Protected MCP Client sample shows how to: +- Connect to an OAuth-protected MCP server +- Handle OAuth 2.0 authorization code flow +- Use custom authorization redirect handling +- Call protected MCP tools with authentication + +## Prerequisites + +- .NET 9.0 or later +- A running TestOAuthServer (for OAuth authentication) +- A running ProtectedMCPServer (for MCP services) + +## Setup and Running + +### Step 1: Start the Test OAuth Server + +First, you need to start the TestOAuthServer which provides OAuth authentication: + +```bash +cd tests\ModelContextProtocol.TestOAuthServer +dotnet run --framework net9.0 +``` + +The OAuth server will start at `https://localhost:7029` + +### Step 2: Start the Protected MCP Server + +Next, start the ProtectedMCPServer which provides the weather tools: + +```bash +cd samples\ProtectedMCPServer +dotnet run +``` + +The protected server will start at `http://localhost:7071` + +### Step 3: Run the Protected MCP Client + +Finally, run this client: + +```bash +cd samples\ProtectedMCPClient +dotnet run +``` + +## What Happens + +1. The client attempts to connect to the protected MCP server at `http://localhost:7071` +2. The server responds with OAuth metadata indicating authentication is required +3. The client initiates OAuth 2.0 authorization code flow: + - Opens a browser to the authorization URL at the OAuth server + - Starts a local HTTP listener on `http://localhost:1179/callback` to receive the authorization code + - Exchanges the authorization code for an access token +4. The client uses the access token to authenticate with the MCP server +5. The client lists available tools and calls the `GetAlerts` tool for Washington state + +## OAuth Configuration + +The client is configured with: +- **Client ID**: `demo-client` +- **Client Secret**: `demo-secret` +- **Redirect URI**: `http://localhost:1179/callback` +- **OAuth Server**: `https://localhost:7029` +- **Protected Resource**: `http://localhost:7071` + +## Available Tools + +Once authenticated, the client can access weather tools including: +- **GetAlerts**: Get weather alerts for a US state +- **GetForecast**: Get weather forecast for a location (latitude/longitude) + +## Troubleshooting + +- Ensure the ASP.NET Core dev certificate is trusted. + ``` + dotnet dev-certs https --clean + dotnet dev-certs https --trust + ``` +- Ensure all three services are running in the correct order +- Check that ports 7029, 7071, and 1179 are available +- If the browser doesn't open automatically, copy the authorization URL from the console and open it manually +- Make sure to allow the OAuth server's self-signed certificate in your browser + +## Key Files + +- `Program.cs`: Main client application with OAuth flow implementation +- `ProtectedMCPClient.csproj`: Project file with dependencies \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Program.cs b/samples/ProtectedMCPServer/Program.cs new file mode 100644 index 000000000..ef70fe731 --- /dev/null +++ b/samples/ProtectedMCPServer/Program.cs @@ -0,0 +1,93 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ProtectedMCPServer.Tools; +using System.Net.Http.Headers; +using System.Security.Claims; + +var builder = WebApplication.CreateBuilder(args); + +var serverUrl = "http://localhost:7071/"; +var inMemoryOAuthServerUrl = "https://localhost:7029/"; + +builder.Services.AddAuthentication(options => +{ + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; +}) +.AddJwtBearer(options => +{ + // Configure to validate tokens from our in-memory OAuth server + options.Authority = inMemoryOAuthServerUrl; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ValidAudience = serverUrl, // Validate that the audience matches the resource metadata as suggested in RFC 8707 + ValidIssuer = inMemoryOAuthServerUrl, + NameClaimType = "name", + RoleClaimType = "roles" + }; + + options.Events = new JwtBearerEvents + { + OnTokenValidated = context => + { + var name = context.Principal?.Identity?.Name ?? "unknown"; + var email = context.Principal?.FindFirstValue("preferred_username") ?? "unknown"; + Console.WriteLine($"Token validated for: {name} ({email})"); + return Task.CompletedTask; + }, + OnAuthenticationFailed = context => + { + Console.WriteLine($"Authentication failed: {context.Exception.Message}"); + return Task.CompletedTask; + }, + OnChallenge = context => + { + Console.WriteLine($"Challenging client to authenticate with Entra ID"); + return Task.CompletedTask; + } + }; +}) +.AddMcp(options => +{ + options.ResourceMetadata = new() + { + Resource = new Uri(serverUrl), + ResourceDocumentation = new Uri("https://docs.example.com/api/weather"), + AuthorizationServers = { new Uri(inMemoryOAuthServerUrl) }, + ScopesSupported = ["mcp:tools"], + }; +}); + +builder.Services.AddAuthorization(); + +builder.Services.AddHttpContextAccessor(); +builder.Services.AddMcpServer() + .WithTools() + .WithHttpTransport(); + +// Configure HttpClientFactory for weather.gov API +builder.Services.AddHttpClient("WeatherApi", client => +{ + client.BaseAddress = new Uri("https://api.weather.gov/"); + client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("weather-tool", "1.0")); +}); + +var app = builder.Build(); + +app.UseAuthentication(); +app.UseAuthorization(); + +// Use the default MCP policy name that we've configured +app.MapMcp().RequireAuthorization(); + +Console.WriteLine($"Starting MCP server with authorization at {serverUrl}"); +Console.WriteLine($"Using in-memory OAuth server at {inMemoryOAuthServerUrl}"); +Console.WriteLine($"Protected Resource Metadata URL: {serverUrl}.well-known/oauth-protected-resource"); +Console.WriteLine("Press Ctrl+C to stop the server"); + +app.Run(serverUrl); diff --git a/samples/ProtectedMCPServer/Properties/launchSettings.json b/samples/ProtectedMCPServer/Properties/launchSettings.json new file mode 100644 index 000000000..31b04db83 --- /dev/null +++ b/samples/ProtectedMCPServer/Properties/launchSettings.json @@ -0,0 +1,12 @@ +{ + "profiles": { + "ProtectedMCPServer": { + "commandName": "Project", + "launchBrowser": true, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + }, + "applicationUrl": "http://localhost:7071/" + } + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPServer/ProtectedMCPServer.csproj b/samples/ProtectedMCPServer/ProtectedMCPServer.csproj new file mode 100644 index 000000000..b4c35c779 --- /dev/null +++ b/samples/ProtectedMCPServer/ProtectedMCPServer.csproj @@ -0,0 +1,15 @@ + + + + net9.0 + enable + enable + 783daef3-9c45-408d-a1d3-7caf44724f39 + + + + + + + + \ No newline at end of file diff --git a/samples/ProtectedMCPServer/README.md b/samples/ProtectedMCPServer/README.md new file mode 100644 index 000000000..f0ac708a0 --- /dev/null +++ b/samples/ProtectedMCPServer/README.md @@ -0,0 +1,125 @@ +# Protected MCP Server Sample + +This sample demonstrates how to create an MCP server that requires OAuth 2.0 authentication to access its tools and resources. The server provides weather-related tools protected by JWT bearer token authentication. + +## Overview + +The Protected MCP Server sample shows how to: +- Create an MCP server with OAuth 2.0 protection +- Configure JWT bearer token authentication +- Implement protected MCP tools and resources +- Integrate with ASP.NET Core authentication and authorization +- Provide OAuth resource metadata for client discovery + +## Prerequisites + +- .NET 9.0 or later +- A running TestOAuthServer (for OAuth authentication) + +## Setup and Running + +### Step 1: Start the Test OAuth Server + +First, you need to start the TestOAuthServer which issues access tokens: + +```bash +cd tests\ModelContextProtocol.TestOAuthServer +dotnet run --framework net9.0 +``` + +The OAuth server will start at `https://localhost:7029` + +### Step 2: Start the Protected MCP Server + +Run this protected server: + +```bash +cd samples\ProtectedMCPServer +dotnet run +``` + +The protected server will start at `http://localhost:7071` + +### Step 3: Test with Protected MCP Client + +You can test the server using the ProtectedMCPClient sample: + +```bash +cd samples\ProtectedMCPClient +dotnet run +``` + +## What the Server Provides + +### Protected Resources + +- **MCP Endpoint**: `http://localhost:7071/` (requires authentication) +- **OAuth Resource Metadata**: `http://localhost:7071/.well-known/oauth-protected-resource` + +### Available Tools + +The server provides weather-related tools that require authentication: + +1. **GetAlerts**: Get weather alerts for a US state + - Parameter: `state` (string) - 2-letter US state abbreviation + - Example: `GetAlerts` with `state: "WA"` + +2. **GetForecast**: Get weather forecast for a location + - Parameters: + - `latitude` (double) - Latitude coordinate + - `longitude` (double) - Longitude coordinate + - Example: `GetForecast` with `latitude: 47.6062, longitude: -122.3321` + +### Authentication Configuration + +The server is configured to: +- Accept JWT bearer tokens from the OAuth server at `https://localhost:7029` +- Validate token audience as `demo-client` +- Require tokens to have appropriate scopes (`mcp:tools`) +- Provide OAuth resource metadata for client discovery + +## Architecture + +The server uses: +- **ASP.NET Core** for hosting and HTTP handling +- **JWT Bearer Authentication** for token validation +- **MCP Authentication Extensions** for OAuth resource metadata +- **HttpClient** for calling the weather.gov API +- **Authorization** to protect MCP endpoints + +## Configuration Details + +- **Server URL**: `http://localhost:7071` +- **OAuth Server**: `https://localhost:7029` +- **Demo Client ID**: `demo-client` + +## Testing Without Client + +You can test the server directly using HTTP tools: + +1. Get an access token from the OAuth server +2. Include the token in the `Authorization: Bearer ` header +3. Make requests to the MCP endpoints + +## External Dependencies + +The weather tools use the National Weather Service API at `api.weather.gov` to fetch real weather data. + +## Troubleshooting + +- Ensure the ASP.NET Core dev certificate is trusted. + ``` + dotnet dev-certs https --clean + dotnet dev-certs https --trust + ``` +- Ensure the TestOAuthServer is running first +- Check that port 7071 is available +- Verify the OAuth server is accessible at `https://localhost:7029` +- Check console output for authentication events and errors + +## Key Files + +- `Program.cs`: Server setup with authentication and MCP configuration +- `Tools/WeatherTools.cs`: Weather tool implementations +- `Tools/HttpClientExt.cs`: HTTP client extensions +- `Properties/launchSettings.json`: Development launch configuration \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Tools/HttpClientExt.cs b/samples/ProtectedMCPServer/Tools/HttpClientExt.cs new file mode 100644 index 000000000..f7b2b5499 --- /dev/null +++ b/samples/ProtectedMCPServer/Tools/HttpClientExt.cs @@ -0,0 +1,13 @@ +using System.Text.Json; + +namespace ModelContextProtocol; + +internal static class HttpClientExt +{ + public static async Task ReadJsonDocumentAsync(this HttpClient client, string requestUri) + { + using var response = await client.GetAsync(requestUri); + response.EnsureSuccessStatusCode(); + return await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); + } +} \ No newline at end of file diff --git a/samples/ProtectedMCPServer/Tools/WeatherTools.cs b/samples/ProtectedMCPServer/Tools/WeatherTools.cs new file mode 100644 index 000000000..7c8c08514 --- /dev/null +++ b/samples/ProtectedMCPServer/Tools/WeatherTools.cs @@ -0,0 +1,67 @@ +using ModelContextProtocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Globalization; +using System.Text.Json; + +namespace ProtectedMCPServer.Tools; + +[McpServerToolType] +public sealed class WeatherTools +{ + private readonly IHttpClientFactory _httpClientFactory; + + public WeatherTools(IHttpClientFactory httpClientFactory) + { + _httpClientFactory = httpClientFactory; + } + + [McpServerTool, Description("Get weather alerts for a US state.")] + public async Task GetAlerts( + [Description("The US state to get alerts for. Use the 2 letter abbreviation for the state (e.g. NY).")] string state) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}"); + var jsonElement = jsonDocument.RootElement; + var alerts = jsonElement.GetProperty("features").EnumerateArray(); + + if (!alerts.Any()) + { + return "No active alerts for this state."; + } + + return string.Join("\n--\n", alerts.Select(alert => + { + JsonElement properties = alert.GetProperty("properties"); + return $""" + Event: {properties.GetProperty("event").GetString()} + Area: {properties.GetProperty("areaDesc").GetString()} + Severity: {properties.GetProperty("severity").GetString()} + Description: {properties.GetProperty("description").GetString()} + Instruction: {properties.GetProperty("instruction").GetString()} + """; + })); + } + + [McpServerTool, Description("Get weather forecast for a location.")] + public async Task GetForecast( + [Description("Latitude of the location.")] double latitude, + [Description("Longitude of the location.")] double longitude) + { + var client = _httpClientFactory.CreateClient("WeatherApi"); + var pointUrl = string.Create(CultureInfo.InvariantCulture, $"/points/{latitude},{longitude}"); + using var jsonDocument = await client.ReadJsonDocumentAsync(pointUrl); + var forecastUrl = jsonDocument.RootElement.GetProperty("properties").GetProperty("forecast").GetString() + ?? throw new Exception($"No forecast URL provided by {client.BaseAddress}points/{latitude},{longitude}"); + + using var forecastDocument = await client.ReadJsonDocumentAsync(forecastUrl); + var periods = forecastDocument.RootElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); + + return string.Join("\n---\n", periods.Select(period => $""" + {period.GetProperty("name").GetString()} + Temperature: {period.GetProperty("temperature").GetInt32()}°F + Wind: {period.GetProperty("windSpeed").GetString()} {period.GetProperty("windDirection").GetString()} + Forecast: {period.GetProperty("detailedForecast").GetString()} + """)); + } +} diff --git a/samples/QuickstartClient/Program.cs b/samples/QuickstartClient/Program.cs index aa25db70f..423af627f 100644 --- a/samples/QuickstartClient/Program.cs +++ b/samples/QuickstartClient/Program.cs @@ -3,6 +3,8 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; using ModelContextProtocol.Client; +using System.Diagnostics; +using System.Runtime.CompilerServices; var builder = Host.CreateApplicationBuilder(args); @@ -89,6 +91,12 @@ static void PromptForInput() [var script] when script.EndsWith(".py") => ("python", args), [var script] when script.EndsWith(".js") => ("node", args), [var script] when Directory.Exists(script) || (File.Exists(script) && script.EndsWith(".csproj")) => ("dotnet", ["run", "--project", script]), - _ => ("dotnet", ["run", "--project", "../QuickstartWeatherServer"]) + _ => ("dotnet", ["run", "--project", Path.Combine(GetCurrentSourceDirectory(), "../QuickstartWeatherServer")]) }; +} + +static string GetCurrentSourceDirectory([CallerFilePath] string? currentFile = null) +{ + Debug.Assert(!string.IsNullOrWhiteSpace(currentFile)); + return Path.GetDirectoryName(currentFile) ?? throw new InvalidOperationException("Unable to determine source directory."); } \ No newline at end of file diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index aa03d7fb4..a096f9301 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -25,9 +25,9 @@ public static async Task SampleLLM( private static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) { - return new CreateMessageRequestParams() + return new CreateMessageRequestParams { - Messages = [new SamplingMessage() + Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 137785bb3..b3d159455 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -6,7 +6,7 @@ https://github.com/modelcontextprotocol/csharp-sdk git 0.3.0 - preview.1 + preview.2 ModelContextProtocolOfficial © Anthropic and Contributors. ModelContextProtocol;mcp;ai;llm diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs new file mode 100644 index 000000000..4c720c65c --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationDefaults.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Default values used by MCP authentication. +/// +public static class McpAuthenticationDefaults +{ + /// + /// The default value used for authentication scheme name. + /// + public const string AuthenticationScheme = "McpAuth"; + + /// + /// The default value used for authentication scheme display name. + /// + public const string DisplayName = "MCP Authentication"; +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs new file mode 100644 index 000000000..0d4302252 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationEvents.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Represents the authentication events for Model Context Protocol. +/// +public class McpAuthenticationEvents +{ + /// + /// Gets or sets the function that is invoked when resource metadata is requested. + /// + /// + /// This function is called when a resource metadata request is made to the protected resource metadata endpoint. + /// The implementer should set the property + /// to provide the appropriate metadata for the current request. + /// + public Func OnResourceMetadataRequest { get; set; } = context => Task.CompletedTask; +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs new file mode 100644 index 000000000..f103357c8 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationExtensions.cs @@ -0,0 +1,47 @@ +using Microsoft.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Authentication; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods for adding MCP authorization support to ASP.NET Core applications. +/// +public static class McpAuthenticationExtensions +{ + /// + /// Adds MCP authorization support to the application. + /// + /// The authentication builder. + /// An action to configure MCP authentication options. + /// The authentication builder for chaining. + public static AuthenticationBuilder AddMcp( + this AuthenticationBuilder builder, + Action? configureOptions = null) + { + return AddMcp( + builder, + McpAuthenticationDefaults.AuthenticationScheme, + McpAuthenticationDefaults.DisplayName, + configureOptions); + } + + /// + /// Adds MCP authorization support to the application with a custom scheme name. + /// + /// The authentication builder. + /// The authentication scheme name to use. + /// The display name for the authentication scheme. + /// An action to configure MCP authentication options. + /// The authentication builder for chaining. + public static AuthenticationBuilder AddMcp( + this AuthenticationBuilder builder, + string authenticationScheme, + string displayName, + Action? configureOptions = null) + { + return builder.AddScheme( + authenticationScheme, + displayName, + configureOptions); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs new file mode 100644 index 000000000..942db1b65 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationHandler.cs @@ -0,0 +1,157 @@ +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Authentication; +using System.Text.Encodings.Web; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Authentication handler for MCP protocol that adds resource metadata to challenge responses +/// and handles resource metadata endpoint requests. +/// +public class McpAuthenticationHandler : AuthenticationHandler, IAuthenticationRequestHandler +{ + /// + /// Initializes a new instance of the class. + /// + public McpAuthenticationHandler( + IOptionsMonitor options, + ILoggerFactory logger, + UrlEncoder encoder) + : base(options, logger, encoder) + { + } + + /// + public async Task HandleRequestAsync() + { + // Check if the request is for the resource metadata endpoint + string requestPath = Request.Path.Value ?? string.Empty; + + string expectedMetadataPath = Options.ResourceMetadataUri?.ToString() ?? string.Empty; + if (Options.ResourceMetadataUri != null && !Options.ResourceMetadataUri.IsAbsoluteUri) + { + // For relative URIs, it's just the path component. + expectedMetadataPath = Options.ResourceMetadataUri.OriginalString; + } + + // If the path doesn't match, let the request continue through the pipeline + if (!string.Equals(requestPath, expectedMetadataPath, StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + var cancellationToken = Request.HttpContext.RequestAborted; + await HandleResourceMetadataRequestAsync(cancellationToken); + return true; + } + + /// + /// Gets the base URL from the current request, including scheme, host, and path base. + /// + private string GetBaseUrl() => $"{Request.Scheme}://{Request.Host}{Request.PathBase}"; + + /// + /// Gets the absolute URI for the resource metadata endpoint. + /// + private string GetAbsoluteResourceMetadataUri() + { + var resourceMetadataUri = Options.ResourceMetadataUri; + + string currentPath = resourceMetadataUri?.ToString() ?? string.Empty; + + if (resourceMetadataUri != null && resourceMetadataUri.IsAbsoluteUri) + { + return currentPath; + } + + // For relative URIs, combine with the base URL + string baseUrl = GetBaseUrl(); + string relativePath = resourceMetadataUri?.OriginalString.TrimStart('/') ?? string.Empty; + + if (!Uri.TryCreate($"{baseUrl.TrimEnd('/')}/{relativePath}", UriKind.Absolute, out var absoluteUri)) + { + throw new InvalidOperationException($"Could not create absolute URI for resource metadata. Base URL: {baseUrl}, Relative Path: {relativePath}"); + } + + return absoluteUri.ToString(); + } + + /// + /// Handles the resource metadata request. + /// + /// A token to cancel the operation. + private async Task HandleResourceMetadataRequestAsync(CancellationToken cancellationToken = default) + { + var resourceMetadata = Options.ResourceMetadata; + + if (Options.Events.OnResourceMetadataRequest is not null) + { + var context = new ResourceMetadataRequestContext(Request.HttpContext, Scheme, Options) + { + ResourceMetadata = CloneResourceMetadata(resourceMetadata), + }; + + await Options.Events.OnResourceMetadataRequest(context); + } + + + if (resourceMetadata == null) + { + throw new InvalidOperationException("ResourceMetadata has not been configured. Please set McpAuthenticationOptions.ResourceMetadata."); + } + + await Results.Json(resourceMetadata, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))).ExecuteAsync(Context); + } + + /// + // If no forwarding is configured, this handler doesn't perform authentication + protected override async Task HandleAuthenticateAsync() => AuthenticateResult.NoResult(); + + /// + protected override Task HandleChallengeAsync(AuthenticationProperties properties) + { + // Get the absolute URI for the resource metadata + string rawPrmDocumentUri = GetAbsoluteResourceMetadataUri(); + + properties ??= new AuthenticationProperties(); + + // Store the resource_metadata in properties in case other handlers need it + properties.Items["resource_metadata"] = rawPrmDocumentUri; + + // Add the WWW-Authenticate header with Bearer scheme and resource metadata + string headerValue = $"Bearer realm=\"{Scheme.Name}\", resource_metadata=\"{rawPrmDocumentUri}\""; + Response.Headers.Append("WWW-Authenticate", headerValue); + + return base.HandleChallengeAsync(properties); + } + + internal static ProtectedResourceMetadata? CloneResourceMetadata(ProtectedResourceMetadata? resourceMetadata) + { + if (resourceMetadata is null) + { + return null; + } + + return new ProtectedResourceMetadata + { + Resource = resourceMetadata.Resource, + AuthorizationServers = [.. resourceMetadata.AuthorizationServers], + BearerMethodsSupported = [.. resourceMetadata.BearerMethodsSupported], + ScopesSupported = [.. resourceMetadata.ScopesSupported], + JwksUri = resourceMetadata.JwksUri, + ResourceSigningAlgValuesSupported = resourceMetadata.ResourceSigningAlgValuesSupported is not null ? [.. resourceMetadata.ResourceSigningAlgValuesSupported] : null, + ResourceName = resourceMetadata.ResourceName, + ResourceDocumentation = resourceMetadata.ResourceDocumentation, + ResourcePolicyUri = resourceMetadata.ResourcePolicyUri, + ResourceTosUri = resourceMetadata.ResourceTosUri, + TlsClientCertificateBoundAccessTokens = resourceMetadata.TlsClientCertificateBoundAccessTokens, + AuthorizationDetailsTypesSupported = resourceMetadata.AuthorizationDetailsTypesSupported is not null ? [.. resourceMetadata.AuthorizationDetailsTypesSupported] : null, + DpopSigningAlgValuesSupported = resourceMetadata.DpopSigningAlgValuesSupported is not null ? [.. resourceMetadata.DpopSigningAlgValuesSupported] : null, + DpopBoundAccessTokensRequired = resourceMetadata.DpopBoundAccessTokensRequired + }; + } + +} diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs new file mode 100644 index 000000000..ecb6c6c82 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/McpAuthenticationOptions.cs @@ -0,0 +1,49 @@ +using Microsoft.AspNetCore.Authentication; +using ModelContextProtocol.Authentication; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Options for the MCP authentication handler. +/// +public class McpAuthenticationOptions : AuthenticationSchemeOptions +{ + private static readonly Uri DefaultResourceMetadataUri = new("/.well-known/oauth-protected-resource", UriKind.Relative); + + /// + /// Initializes a new instance of the class. + /// + public McpAuthenticationOptions() + { + // "Bearer" is JwtBearerDefaults.AuthenticationScheme, but we don't have a reference to the JwtBearer package here. + ForwardAuthenticate = "Bearer"; + ResourceMetadataUri = DefaultResourceMetadataUri; + Events = new McpAuthenticationEvents(); + } + + /// + /// Gets or sets the events used to handle authentication events. + /// + public new McpAuthenticationEvents Events + { + get { return (McpAuthenticationEvents)base.Events!; } + set { base.Events = value; } + } + + /// + /// The URI to the resource metadata document. + /// + /// + /// This URI will be included in the WWW-Authenticate header when a 401 response is returned. + /// + public Uri ResourceMetadataUri { get; set; } + + /// + /// Gets or sets the protected resource metadata. + /// + /// + /// This contains the OAuth metadata for the protected resource, including authorization servers, + /// supported scopes, and other information needed for clients to authenticate. + /// + public ProtectedResourceMetadata? ResourceMetadata { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/Authentication/ResourceMetadataRequestContext.cs b/src/ModelContextProtocol.AspNetCore/Authentication/ResourceMetadataRequestContext.cs new file mode 100644 index 000000000..0d064123e --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/Authentication/ResourceMetadataRequestContext.cs @@ -0,0 +1,30 @@ +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using ModelContextProtocol.Authentication; + +namespace ModelContextProtocol.AspNetCore.Authentication; + +/// +/// Context for resource metadata request events. +/// +public class ResourceMetadataRequestContext : HandleRequestContext +{ + /// + /// Initializes a new instance of the class. + /// + /// The HTTP context. + /// The authentication scheme. + /// The authentication options. + public ResourceMetadataRequestContext( + HttpContext context, + AuthenticationScheme scheme, + McpAuthenticationOptions options) + : base(context, scheme, options) + { + } + + /// + /// Gets or sets the protected resource metadata for the current request. + /// + public ProtectedResourceMetadata? ResourceMetadata { get; set; } +} diff --git a/src/ModelContextProtocol.Core/AIContentExtensions.cs b/src/ModelContextProtocol.Core/AIContentExtensions.cs index e8c5d7e33..6b6a9c780 100644 --- a/src/ModelContextProtocol.Core/AIContentExtensions.cs +++ b/src/ModelContextProtocol.Core/AIContentExtensions.cs @@ -39,6 +39,29 @@ public static ChatMessage ToChatMessage(this PromptMessage promptMessage) }; } + /// + /// Converts a to a object. + /// + /// The tool result to convert. + /// The identifier for the function call request that triggered the tool invocation. + /// A object created from the tool result. + /// + /// This method transforms a protocol-specific from the Model Context Protocol + /// into a standard object that can be used with AI client libraries. It produces a + /// message containing a with result as a + /// serialized . + /// + public static ChatMessage ToChatMessage(this CallToolResult result, string callId) + { + Throw.IfNull(result); + Throw.IfNull(callId); + + return new(ChatRole.Tool, [new FunctionResultContent(callId, JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResult)) + { + RawRepresentation = result, + }]); + } + /// /// Converts a to a list of objects. /// @@ -188,21 +211,21 @@ internal static ContentBlock ToContent(this AIContent content) => Text = textContent.Text, }, - DataContent dataContent when dataContent.HasTopLevelMediaType("image") => new ImageContentBlock() + DataContent dataContent when dataContent.HasTopLevelMediaType("image") => new ImageContentBlock { Data = dataContent.Base64Data.ToString(), MimeType = dataContent.MediaType, }, - DataContent dataContent when dataContent.HasTopLevelMediaType("audio") => new AudioContentBlock() + DataContent dataContent when dataContent.HasTopLevelMediaType("audio") => new AudioContentBlock { Data = dataContent.Base64Data.ToString(), MimeType = dataContent.MediaType, }, - DataContent dataContent => new EmbeddedResourceBlock() + DataContent dataContent => new EmbeddedResourceBlock { - Resource = new BlobResourceContents() + Resource = new BlobResourceContents { Blob = dataContent.Base64Data.ToString(), MimeType = dataContent.MediaType, diff --git a/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs b/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs new file mode 100644 index 000000000..1cc081895 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthenticatingMcpHttpClient.cs @@ -0,0 +1,118 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using System.Net.Http.Headers; + +namespace ModelContextProtocol.Authentication; + +/// +/// A delegating handler that adds authentication tokens to requests and handles 401 responses. +/// +internal sealed class AuthenticatingMcpHttpClient(HttpClient httpClient, ClientOAuthProvider credentialProvider) : McpHttpClient(httpClient) +{ + // Select first supported scheme as the default + private string _currentScheme = credentialProvider.SupportedSchemes.FirstOrDefault() ?? + throw new ArgumentException("Authorization provider must support at least one authentication scheme.", nameof(credentialProvider)); + + /// + /// Sends an HTTP request with authentication handling. + /// + internal override async Task SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken) + { + if (request.Headers.Authorization == null) + { + await AddAuthorizationHeaderAsync(request, _currentScheme, cancellationToken).ConfigureAwait(false); + } + + var response = await base.SendAsync(request, message, cancellationToken).ConfigureAwait(false); + + if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized) + { + return await HandleUnauthorizedResponseAsync(request, message, response, cancellationToken).ConfigureAwait(false); + } + + return response; + } + + /// + /// Handles a 401 Unauthorized response by attempting to authenticate and retry the request. + /// + private async Task HandleUnauthorizedResponseAsync( + HttpRequestMessage originalRequest, + JsonRpcMessage? originalJsonRpcMessage, + HttpResponseMessage response, + CancellationToken cancellationToken) + { + // Gather the schemes the server wants us to use from WWW-Authenticate headers + var serverSchemes = ExtractServerSupportedSchemes(response); + + if (!serverSchemes.Contains(_currentScheme)) + { + // Find the first server scheme that's in our supported set + var bestSchemeMatch = serverSchemes.Intersect(credentialProvider.SupportedSchemes, StringComparer.OrdinalIgnoreCase).FirstOrDefault(); + + if (bestSchemeMatch is not null) + { + _currentScheme = bestSchemeMatch; + } + else if (serverSchemes.Count > 0) + { + // If no match was found, either throw an exception or use default + throw new McpException( + $"The server does not support any of the provided authentication schemes." + + $"Server supports: [{string.Join(", ", serverSchemes)}], " + + $"Provider supports: [{string.Join(", ", credentialProvider.SupportedSchemes)}]."); + } + } + + // Try to handle the 401 response with the selected scheme + await credentialProvider.HandleUnauthorizedResponseAsync(_currentScheme, response, cancellationToken).ConfigureAwait(false); + + using var retryRequest = new HttpRequestMessage(originalRequest.Method, originalRequest.RequestUri); + + // Copy headers except Authorization which we'll set separately + foreach (var header in originalRequest.Headers) + { + if (!header.Key.Equals("Authorization", StringComparison.OrdinalIgnoreCase)) + { + retryRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + } + + await AddAuthorizationHeaderAsync(retryRequest, _currentScheme, cancellationToken).ConfigureAwait(false); + return await base.SendAsync(retryRequest, originalJsonRpcMessage, cancellationToken).ConfigureAwait(false); + } + + /// + /// Extracts the authentication schemes that the server supports from the WWW-Authenticate headers. + /// + private static HashSet ExtractServerSupportedSchemes(HttpResponseMessage response) + { + var serverSchemes = new HashSet(StringComparer.OrdinalIgnoreCase); + + foreach (var header in response.Headers.WwwAuthenticate) + { + serverSchemes.Add(header.Scheme); + } + + return serverSchemes; + } + + /// + /// Adds an authorization header to the request. + /// + private async Task AddAuthorizationHeaderAsync(HttpRequestMessage request, string scheme, CancellationToken cancellationToken) + { + if (request.RequestUri is null) + { + return; + } + + var token = await credentialProvider.GetCredentialAsync(scheme, request.RequestUri, cancellationToken).ConfigureAwait(false); + if (string.IsNullOrEmpty(token)) + { + return; + } + + request.Headers.Authorization = new AuthenticationHeaderValue(scheme, token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs b/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs new file mode 100644 index 000000000..d3c33231f --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthorizationRedirectDelegate.cs @@ -0,0 +1,28 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a method that handles the OAuth authorization URL and returns the authorization code. +/// +/// The authorization URL that the user needs to visit. +/// The redirect URI where the authorization code will be sent. +/// The cancellation token. +/// A task that represents the asynchronous operation. The task result contains the authorization code if successful, or null if the operation failed or was cancelled. +/// +/// +/// This delegate provides SDK consumers with full control over how the OAuth authorization flow is handled. +/// Implementers can choose to: +/// +/// +/// Start a local HTTP server and open a browser (default behavior) +/// Display the authorization URL to the user for manual handling +/// Integrate with a custom UI or authentication flow +/// Use a different redirect mechanism altogether +/// +/// +/// The implementation should handle user interaction to visit the authorization URL and extract +/// the authorization code from the callback. The authorization code is typically provided as +/// a query parameter in the redirect URI callback. +/// +/// +public delegate Task AuthorizationRedirectDelegate(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs b/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs new file mode 100644 index 000000000..e94fce7a9 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/AuthorizationServerMetadata.cs @@ -0,0 +1,69 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents the metadata about an OAuth authorization server. +/// +internal sealed class AuthorizationServerMetadata +{ + /// + /// The authorization endpoint URI. + /// + [JsonPropertyName("authorization_endpoint")] + public Uri AuthorizationEndpoint { get; set; } = null!; + + /// + /// The token endpoint URI. + /// + [JsonPropertyName("token_endpoint")] + public Uri TokenEndpoint { get; set; } = null!; + + /// + /// The registration endpoint URI. + /// + [JsonPropertyName("registration_endpoint")] + public Uri? RegistrationEndpoint { get; set; } + + /// + /// The revocation endpoint URI. + /// + [JsonPropertyName("revocation_endpoint")] + public Uri? RevocationEndpoint { get; set; } + + /// + /// The response types supported by the authorization server. + /// + [JsonPropertyName("response_types_supported")] + public List? ResponseTypesSupported { get; set; } + + /// + /// The grant types supported by the authorization server. + /// + [JsonPropertyName("grant_types_supported")] + public List? GrantTypesSupported { get; set; } + + /// + /// The token endpoint authentication methods supported by the authorization server. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public List? TokenEndpointAuthMethodsSupported { get; set; } + + /// + /// The code challenge methods supported by the authorization server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + public List? CodeChallengeMethodsSupported { get; set; } + + /// + /// The issuer URI of the authorization server. + /// + [JsonPropertyName("issuer")] + public Uri? Issuer { get; set; } + + /// + /// The scopes supported by the authorization server. + /// + [JsonPropertyName("scopes_supported")] + public List? ScopesSupported { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs new file mode 100644 index 000000000..686316f55 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -0,0 +1,99 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Provides configuration options for the . +/// +public sealed class ClientOAuthOptions +{ + /// + /// Gets or sets the OAuth redirect URI. + /// + public required Uri RedirectUri { get; set; } + + /// + /// Gets or sets the OAuth client ID. If not provided, the client will attempt to register dynamically. + /// + public string? ClientId { get; set; } + + /// + /// Gets or sets the OAuth client secret. + /// + /// + /// This is optional for public clients or when using PKCE without client authentication. + /// + public string? ClientSecret { get; set; } + + /// + /// Gets or sets the OAuth scopes to request. + /// + /// + /// + /// When specified, these scopes will be used instead of the scopes advertised by the protected resource. + /// If not specified, the provider will use the scopes from the protected resource metadata. + /// + /// + /// Common OAuth scopes include "openid", "profile", "email", etc. + /// + /// + public IEnumerable? Scopes { get; set; } + + /// + /// Gets or sets the authorization redirect delegate for handling the OAuth authorization flow. + /// + /// + /// + /// This delegate is responsible for handling the OAuth authorization URL and obtaining the authorization code. + /// If not specified, a default implementation will be used that prompts the user to enter the code manually. + /// + /// + /// Custom implementations might open a browser, start an HTTP listener, or use other mechanisms to capture + /// the authorization code from the OAuth redirect. + /// + /// + public AuthorizationRedirectDelegate? AuthorizationRedirectDelegate { get; set; } + + /// + /// Gets or sets the authorization server selector function. + /// + /// + /// + /// This function is used to select which authorization server to use when multiple servers are available. + /// If not specified, the first available server will be selected. + /// + /// + /// The function receives a list of available authorization server URIs and should return the selected server, + /// or null if no suitable server is found. + /// + /// + public Func, Uri?>? AuthServerSelector { get; set; } + + /// + /// Gets or sets the client name to use during dynamic client registration. + /// + /// + /// This is a human-readable name for the client that may be displayed to users during authorization. + /// Only used when a is not specified. + /// + public string? ClientName { get; set; } + + /// + /// Gets or sets the client URI to use during dynamic client registration. + /// + /// + /// This should be a URL pointing to the client's home page or information page. + /// Only used when a is not specified. + /// + public Uri? ClientUri { get; set; } + + /// + /// Gets or sets additional parameters to include in the query string of the OAuth authorization request + /// providing extra information or fulfilling specific requirements of the OAuth provider. + /// + /// + /// + /// Parameters specified cannot override or append to any automatically set parameters like the "redirect_uri" + /// which should instead be configured via . + /// + /// + public IDictionary AdditionalAuthorizationParameters { get; set; } = new Dictionary(); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs new file mode 100644 index 000000000..96356028f --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -0,0 +1,687 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using System.Collections.Specialized; +using System.Diagnostics.CodeAnalysis; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using System.Web; + +namespace ModelContextProtocol.Authentication; + +/// +/// A generic implementation of an OAuth authorization provider for MCP. This does not do any advanced token +/// protection or caching - it acquires a token and server metadata and holds it in memory. +/// This is suitable for demonstration and development purposes. +/// +internal sealed partial class ClientOAuthProvider +{ + /// + /// The Bearer authentication scheme. + /// + private const string BearerScheme = "Bearer"; + + private readonly Uri _serverUrl; + private readonly Uri _redirectUri; + private readonly string[]? _scopes; + private readonly IDictionary _additionalAuthorizationParameters; + private readonly Func, Uri?> _authServerSelector; + private readonly AuthorizationRedirectDelegate _authorizationRedirectDelegate; + + // _clientName and _client URI is used for dynamic client registration (RFC 7591) + private readonly string? _clientName; + private readonly Uri? _clientUri; + + private readonly HttpClient _httpClient; + private readonly ILogger _logger; + + private string? _clientId; + private string? _clientSecret; + + private TokenContainer? _token; + private AuthorizationServerMetadata? _authServerMetadata; + + /// + /// Initializes a new instance of the class using the specified options. + /// + /// The MCP server URL. + /// The OAuth provider configuration options. + /// The HTTP client to use for OAuth requests. If null, a default HttpClient will be used. + /// A logger factory to handle diagnostic messages. + /// Thrown when serverUrl or options are null. + public ClientOAuthProvider( + Uri serverUrl, + ClientOAuthOptions options, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + { + _serverUrl = serverUrl ?? throw new ArgumentNullException(nameof(serverUrl)); + _httpClient = httpClient ?? new HttpClient(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + if (options is null) + { + throw new ArgumentNullException(nameof(options)); + } + + _clientId = options.ClientId; + _clientSecret = options.ClientSecret; + _redirectUri = options.RedirectUri ?? throw new ArgumentException("ClientOAuthOptions.RedirectUri must configured."); + _clientName = options.ClientName; + _clientUri = options.ClientUri; + _scopes = options.Scopes?.ToArray(); + _additionalAuthorizationParameters = options.AdditionalAuthorizationParameters; + + // Set up authorization server selection strategy + _authServerSelector = options.AuthServerSelector ?? DefaultAuthServerSelector; + + // Set up authorization URL handler (use default if not provided) + _authorizationRedirectDelegate = options.AuthorizationRedirectDelegate ?? DefaultAuthorizationUrlHandler; + } + + /// + /// Default authorization server selection strategy that selects the first available server. + /// + /// List of available authorization servers. + /// The selected authorization server, or null if none are available. + private static Uri? DefaultAuthServerSelector(IReadOnlyList availableServers) => availableServers.FirstOrDefault(); + + /// + /// Default authorization URL handler that displays the URL to the user for manual input. + /// + /// The authorization URL to handle. + /// The redirect URI where the authorization code will be sent. + /// The cancellation token. + /// The authorization code entered by the user, or null if none was provided. + private static Task DefaultAuthorizationUrlHandler(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken) + { + Console.WriteLine($"Please open the following URL in your browser to authorize the application:"); + Console.WriteLine($"{authorizationUrl}"); + Console.WriteLine(); + Console.Write("Enter the authorization code from the redirect URL: "); + var authorizationCode = Console.ReadLine(); + return Task.FromResult(authorizationCode); + } + + /// + /// Gets the collection of authentication schemes supported by this provider. + /// + /// + /// + /// This property returns all authentication schemes that this provider can handle, + /// allowing clients to select the appropriate scheme based on server capabilities. + /// + /// + /// Common values include "Bearer" for JWT tokens, "Basic" for username/password authentication, + /// and "Negotiate" for integrated Windows authentication. + /// + /// + public IEnumerable SupportedSchemes => [BearerScheme]; + + /// + /// Gets an authentication token or credential for authenticating requests to a resource + /// using the specified authentication scheme. + /// + /// The authentication scheme to use. + /// The URI of the resource requiring authentication. + /// A token to cancel the operation. + /// An authentication token string or null if no token could be obtained for the specified scheme. + public async Task GetCredentialAsync(string scheme, Uri resourceUri, CancellationToken cancellationToken = default) + { + ThrowIfNotBearerScheme(scheme); + + // Return the token if it's valid + if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + { + return _token.AccessToken; + } + + // Try to refresh the token if we have a refresh token + if (_token?.RefreshToken != null && _authServerMetadata != null) + { + var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + if (newToken != null) + { + _token = newToken; + return _token.AccessToken; + } + } + + // No valid token - auth handler will trigger the 401 flow + return null; + } + + /// + /// Handles a 401 Unauthorized response from a resource. + /// + /// The authentication scheme that was used when the unauthorized response was received. + /// The HTTP response that contained the 401 status code. + /// A token to cancel the operation. + /// + /// A result object indicating if the provider was able to handle the unauthorized response, + /// and the authentication scheme that should be used for the next attempt, if any. + /// + public async Task HandleUnauthorizedResponseAsync( + string scheme, + HttpResponseMessage response, + CancellationToken cancellationToken = default) + { + // This provider only supports Bearer scheme + if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException("This credential provider only supports the Bearer scheme"); + } + + await PerformOAuthAuthorizationAsync(response, cancellationToken).ConfigureAwait(false); + } + + /// + /// Performs OAuth authorization by selecting an appropriate authorization server and completing the OAuth flow. + /// + /// The 401 Unauthorized response containing authentication challenge. + /// Cancellation token. + /// Result indicating whether authorization was successful. + private async Task PerformOAuthAuthorizationAsync( + HttpResponseMessage response, + CancellationToken cancellationToken) + { + // Get available authorization servers from the 401 response + var protectedResourceMetadata = await ExtractProtectedResourceMetadata(response, _serverUrl, cancellationToken).ConfigureAwait(false); + var availableAuthorizationServers = protectedResourceMetadata.AuthorizationServers; + + if (availableAuthorizationServers.Count == 0) + { + ThrowFailedToHandleUnauthorizedResponse("No authorization servers found in authentication challenge"); + } + + // Select authorization server using configured strategy + var selectedAuthServer = _authServerSelector(availableAuthorizationServers); + + if (selectedAuthServer is null) + { + ThrowFailedToHandleUnauthorizedResponse($"Authorization server selection returned null. Available servers: {string.Join(", ", availableAuthorizationServers)}"); + } + + if (!availableAuthorizationServers.Contains(selectedAuthServer)) + { + ThrowFailedToHandleUnauthorizedResponse($"Authorization server selector returned a server not in the available list: {selectedAuthServer}. Available servers: {string.Join(", ", availableAuthorizationServers)}"); + } + + LogSelectedAuthorizationServer(selectedAuthServer, availableAuthorizationServers.Count); + + // Get auth server metadata + var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken).ConfigureAwait(false); + + if (authServerMetadata is null) + { + ThrowFailedToHandleUnauthorizedResponse($"Failed to retrieve metadata for authorization server: '{selectedAuthServer}'"); + } + + // Store auth server metadata for future refresh operations + _authServerMetadata = authServerMetadata; + + // Perform dynamic client registration if needed + if (string.IsNullOrEmpty(_clientId)) + { + await PerformDynamicClientRegistrationAsync(authServerMetadata, cancellationToken).ConfigureAwait(false); + } + + // Perform the OAuth flow + var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); + + if (token is null) + { + ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); + } + + _token = token; + LogOAuthAuthorizationCompleted(); + } + + private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) + { + if (!authServerUri.OriginalString.EndsWith("/")) + { + authServerUri = new Uri(authServerUri.OriginalString + "/"); + } + + foreach (var path in new[] { ".well-known/openid-configuration", ".well-known/oauth-authorization-server" }) + { + try + { + var response = await _httpClient.GetAsync(new Uri(authServerUri, path), cancellationToken).ConfigureAwait(false); + if (!response.IsSuccessStatusCode) + { + continue; + } + + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var metadata = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata, cancellationToken).ConfigureAwait(false); + + if (metadata != null) + { + metadata.ResponseTypesSupported ??= ["code"]; + metadata.GrantTypesSupported ??= ["authorization_code", "refresh_token"]; + metadata.TokenEndpointAuthMethodsSupported ??= ["client_secret_post"]; + metadata.CodeChallengeMethodsSupported ??= ["S256"]; + + return metadata; + } + } + catch (Exception ex) + { + LogErrorFetchingAuthServerMetadata(ex, path); + } + } + + return null; + } + + private async Task RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) + { + var requestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "refresh_token", + ["refresh_token"] = refreshToken, + ["client_id"] = GetClientIdOrThrow(), + ["client_secret"] = _clientSecret ?? string.Empty, + ["resource"] = resourceUri.ToString(), + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + { + Content = requestContent + }; + + return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task InitiateAuthorizationCodeFlowAsync( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + CancellationToken cancellationToken) + { + var codeVerifier = GenerateCodeVerifier(); + var codeChallenge = GenerateCodeChallenge(codeVerifier); + + var authUrl = BuildAuthorizationUrl(protectedResourceMetadata, authServerMetadata, codeChallenge); + var authCode = await _authorizationRedirectDelegate(authUrl, _redirectUri, cancellationToken).ConfigureAwait(false); + + if (string.IsNullOrEmpty(authCode)) + { + return null; + } + + return await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false); + } + + private Uri BuildAuthorizationUrl( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + string codeChallenge) + { + if (authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttp && + authServerMetadata.AuthorizationEndpoint.Scheme != Uri.UriSchemeHttps) + { + throw new ArgumentException("AuthorizationEndpoint must use HTTP or HTTPS.", nameof(authServerMetadata)); + } + + var queryParamsDictionary = new Dictionary + { + ["client_id"] = GetClientIdOrThrow(), + ["redirect_uri"] = _redirectUri.ToString(), + ["response_type"] = "code", + ["code_challenge"] = codeChallenge, + ["code_challenge_method"] = "S256", + ["resource"] = protectedResourceMetadata.Resource.ToString(), + }; + + var scopesSupported = protectedResourceMetadata.ScopesSupported; + if (_scopes is not null || scopesSupported.Count > 0) + { + queryParamsDictionary["scope"] = string.Join(" ", _scopes ?? scopesSupported.ToArray()); + } + + // Add extra parameters if provided. Load into a dictionary before constructing to avoid overwiting values. + foreach (var kvp in _additionalAuthorizationParameters) + { + queryParamsDictionary.Add(kvp.Key, kvp.Value); + } + + var queryParams = HttpUtility.ParseQueryString(string.Empty); + foreach (var kvp in queryParamsDictionary) + { + queryParams[kvp.Key] = kvp.Value; + } + + var uriBuilder = new UriBuilder(authServerMetadata.AuthorizationEndpoint) + { + Query = queryParams.ToString() + }; + + return uriBuilder.Uri; + } + + private async Task ExchangeCodeForTokenAsync( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + string authorizationCode, + string codeVerifier, + CancellationToken cancellationToken) + { + var requestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "authorization_code", + ["code"] = authorizationCode, + ["redirect_uri"] = _redirectUri.ToString(), + ["client_id"] = GetClientIdOrThrow(), + ["code_verifier"] = codeVerifier, + ["client_secret"] = _clientSecret ?? string.Empty, + ["resource"] = protectedResourceMetadata.Resource.ToString(), + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + { + Content = requestContent + }; + + return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task FetchTokenAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + + using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false); + + if (tokenResponse is null) + { + ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response."); + } + + tokenResponse.ObtainedAt = DateTimeOffset.UtcNow; + return tokenResponse; + } + + /// + /// Fetches the protected resource metadata from the provided URL. + /// + /// The URL to fetch the metadata from. + /// A token to cancel the operation. + /// The fetched ProtectedResourceMetadata, or null if it couldn't be fetched. + private async Task FetchProtectedResourceMetadataAsync(Uri metadataUrl, CancellationToken cancellationToken = default) + { + using var httpResponse = await _httpClient.GetAsync(metadataUrl, cancellationToken).ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + + using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + return await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.ProtectedResourceMetadata, cancellationToken).ConfigureAwait(false); + } + + /// + /// Performs dynamic client registration with the authorization server. + /// + /// The authorization server metadata. + /// Cancellation token. + /// A task representing the asynchronous operation. + private async Task PerformDynamicClientRegistrationAsync( + AuthorizationServerMetadata authServerMetadata, + CancellationToken cancellationToken) + { + if (authServerMetadata.RegistrationEndpoint is null) + { + ThrowFailedToHandleUnauthorizedResponse("Authorization server does not support dynamic client registration"); + } + + LogPerformingDynamicClientRegistration(authServerMetadata.RegistrationEndpoint); + + var registrationRequest = new DynamicClientRegistrationRequest + { + RedirectUris = [_redirectUri.ToString()], + GrantTypes = ["authorization_code", "refresh_token"], + ResponseTypes = ["code"], + TokenEndpointAuthMethod = "client_secret_post", + ClientName = _clientName, + ClientUri = _clientUri?.ToString(), + Scope = _scopes is not null ? string.Join(" ", _scopes) : null + }; + + var requestJson = JsonSerializer.Serialize(registrationRequest, McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationRequest); + using var requestContent = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.RegistrationEndpoint) + { + Content = requestContent + }; + + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + + if (!httpResponse.IsSuccessStatusCode) + { + var errorContent = await httpResponse.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + ThrowFailedToHandleUnauthorizedResponse($"Dynamic client registration failed with status {httpResponse.StatusCode}: {errorContent}"); + } + + using var responseStream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var registrationResponse = await JsonSerializer.DeserializeAsync( + responseStream, + McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationResponse, + cancellationToken).ConfigureAwait(false); + + if (registrationResponse is null) + { + ThrowFailedToHandleUnauthorizedResponse("Dynamic client registration returned empty response"); + } + + // Update client credentials + _clientId = registrationResponse.ClientId; + if (!string.IsNullOrEmpty(registrationResponse.ClientSecret)) + { + _clientSecret = registrationResponse.ClientSecret; + } + + LogDynamicClientRegistrationSuccessful(_clientId!); + } + + /// + /// Verifies that the resource URI in the metadata exactly matches the original request URL as required by the RFC. + /// Per RFC: The resource value must be identical to the URL that the client used to make the request to the resource server. + /// + /// The metadata to verify. + /// The original URL the client used to make the request to the resource server. + /// True if the resource URI exactly matches the original request URL, otherwise false. + private static bool VerifyResourceMatch(ProtectedResourceMetadata protectedResourceMetadata, Uri resourceLocation) + { + if (protectedResourceMetadata.Resource == null || resourceLocation == null) + { + return false; + } + + // Per RFC: The resource value must be identical to the URL that the client used + // to make the request to the resource server. Compare entire URIs, not just the host. + + // Normalize the URIs to ensure consistent comparison + string normalizedMetadataResource = NormalizeUri(protectedResourceMetadata.Resource); + string normalizedResourceLocation = NormalizeUri(resourceLocation); + + return string.Equals(normalizedMetadataResource, normalizedResourceLocation, StringComparison.OrdinalIgnoreCase); + } + + /// + /// Normalizes a URI for consistent comparison. + /// + /// The URI to normalize. + /// A normalized string representation of the URI. + private static string NormalizeUri(Uri uri) + { + var builder = new UriBuilder(uri) + { + Port = -1 // Always remove port + }; + + if (builder.Path == "/") + { + builder.Path = string.Empty; + } + else if (builder.Path.Length > 1 && builder.Path.EndsWith("/")) + { + builder.Path = builder.Path.TrimEnd('/'); + } + + return builder.Uri.ToString(); + } + + /// + /// Responds to a 401 challenge by parsing the WWW-Authenticate header, fetching the resource metadata, + /// verifying the resource match, and returning the metadata if valid. + /// + /// The HTTP response containing the WWW-Authenticate header. + /// The server URL to verify against the resource metadata. + /// A token to cancel the operation. + /// The resource metadata if the resource matches the server, otherwise throws an exception. + /// Thrown when the response is not a 401, lacks a WWW-Authenticate header, + /// lacks a resource_metadata parameter, the metadata can't be fetched, or the resource URI doesn't match the server URL. + private async Task ExtractProtectedResourceMetadata(HttpResponseMessage response, Uri serverUrl, CancellationToken cancellationToken = default) + { + if (response.StatusCode != System.Net.HttpStatusCode.Unauthorized) + { + throw new InvalidOperationException($"Expected a 401 Unauthorized response, but received {(int)response.StatusCode} {response.StatusCode}"); + } + + // Extract the WWW-Authenticate header + if (response.Headers.WwwAuthenticate.Count == 0) + { + throw new McpException("The 401 response does not contain a WWW-Authenticate header"); + } + + // Look for the Bearer authentication scheme with resource_metadata parameter + string? resourceMetadataUrl = null; + foreach (var header in response.Headers.WwwAuthenticate) + { + if (string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter)) + { + resourceMetadataUrl = ParseWwwAuthenticateParameters(header.Parameter, "resource_metadata"); + if (resourceMetadataUrl != null) + { + break; + } + } + } + + if (resourceMetadataUrl == null) + { + throw new McpException("The WWW-Authenticate header does not contain a resource_metadata parameter"); + } + + Uri metadataUri = new(resourceMetadataUrl); + var metadata = await FetchProtectedResourceMetadataAsync(metadataUri, cancellationToken).ConfigureAwait(false) + ?? throw new McpException($"Failed to fetch resource metadata from {resourceMetadataUrl}"); + + // Per RFC: The resource value must be identical to the URL that the client used + // to make the request to the resource server + LogValidatingResourceMetadata(serverUrl); + + if (!VerifyResourceMatch(metadata, serverUrl)) + { + throw new McpException($"Resource URI in metadata ({metadata.Resource}) does not match the expected URI ({serverUrl})"); + } + + return metadata; + } + + /// + /// Parses the WWW-Authenticate header parameters to extract a specific parameter. + /// + /// The parameter string from the WWW-Authenticate header. + /// The name of the parameter to extract. + /// The value of the parameter, or null if not found. + private static string? ParseWwwAuthenticateParameters(string parameters, string parameterName) + { + if (parameters.IndexOf(parameterName, StringComparison.OrdinalIgnoreCase) == -1) + { + return null; + } + + foreach (var part in parameters.Split(',')) + { + string trimmedPart = part.Trim(); + int equalsIndex = trimmedPart.IndexOf('='); + + if (equalsIndex <= 0) + { + continue; + } + + string key = trimmedPart.Substring(0, equalsIndex).Trim(); + + if (string.Equals(key, parameterName, StringComparison.OrdinalIgnoreCase)) + { + string value = trimmedPart.Substring(equalsIndex + 1).Trim(); + + if (value.StartsWith("\"") && value.EndsWith("\"")) + { + value = value.Substring(1, value.Length - 2); + } + + return value; + } + } + + return null; + } + + private static string GenerateCodeVerifier() + { + var bytes = new byte[32]; + using var rng = RandomNumberGenerator.Create(); + rng.GetBytes(bytes); + return Convert.ToBase64String(bytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + } + + private static string GenerateCodeChallenge(string codeVerifier) + { + using var sha256 = SHA256.Create(); + var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); + return Convert.ToBase64String(challengeBytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + } + + private string GetClientIdOrThrow() => _clientId ?? throw new InvalidOperationException("Client ID is not available. This may indicate an issue with dynamic client registration."); + + private static void ThrowIfNotBearerScheme(string scheme) + { + if (!string.Equals(scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException($"The '{scheme}' is not supported. This credential provider only supports the '{BearerScheme}' scheme"); + } + } + + [DoesNotReturn] + private static void ThrowFailedToHandleUnauthorizedResponse(string message) => + throw new McpException($"Failed to handle unauthorized response with 'Bearer' scheme. {message}"); + + [LoggerMessage(Level = LogLevel.Information, Message = "Selected authorization server: {Server} from {Count} available servers")] + partial void LogSelectedAuthorizationServer(Uri server, int count); + + [LoggerMessage(Level = LogLevel.Information, Message = "OAuth authorization completed successfully")] + partial void LogOAuthAuthorizationCompleted(); + + [LoggerMessage(Level = LogLevel.Error, Message = "Error fetching auth server metadata from {Path}")] + partial void LogErrorFetchingAuthServerMetadata(Exception ex, string path); + + [LoggerMessage(Level = LogLevel.Information, Message = "Performing dynamic client registration with {RegistrationEndpoint}")] + partial void LogPerformingDynamicClientRegistration(Uri registrationEndpoint); + + [LoggerMessage(Level = LogLevel.Information, Message = "Dynamic client registration successful. Client ID: {ClientId}")] + partial void LogDynamicClientRegistrationSuccessful(string clientId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Validating resource metadata against original server URL: {ServerUrl}")] + partial void LogValidatingResourceMetadata(Uri serverUrl); +} diff --git a/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationRequest.cs b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationRequest.cs new file mode 100644 index 000000000..8496610e7 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationRequest.cs @@ -0,0 +1,51 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a client registration request for OAuth 2.0 Dynamic Client Registration (RFC 7591). +/// +internal sealed class DynamicClientRegistrationRequest +{ + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public required string[] RedirectUris { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + public string[]? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + public string[]? ResponseTypes { get; init; } + + /// + /// Gets or sets the human-readable name of the client. + /// + [JsonPropertyName("client_name")] + public string? ClientName { get; init; } + + /// + /// Gets or sets the URL of the client's home page. + /// + [JsonPropertyName("client_uri")] + public string? ClientUri { get; init; } + + /// + /// Gets or sets the scope values that the client will use. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs new file mode 100644 index 000000000..dcd51d68a --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/DynamicClientRegistrationResponse.cs @@ -0,0 +1,57 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a client registration response for OAuth 2.0 Dynamic Client Registration (RFC 7591). +/// +internal sealed class DynamicClientRegistrationResponse +{ + /// + /// Gets or sets the client identifier. + /// + [JsonPropertyName("client_id")] + public required string ClientId { get; init; } + + /// + /// Gets or sets the client secret. + /// + [JsonPropertyName("client_secret")] + public string? ClientSecret { get; init; } + + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public string[]? RedirectUris { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + public string[]? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + public string[]? ResponseTypes { get; init; } + + /// + /// Gets or sets the client ID issued timestamp. + /// + [JsonPropertyName("client_id_issued_at")] + public long? ClientIdIssuedAt { get; init; } + + /// + /// Gets or sets the client secret expiration time. + /// + [JsonPropertyName("client_secret_expires_at")] + public long? ClientSecretExpiresAt { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs b/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs new file mode 100644 index 000000000..88b5bcc08 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ProtectedResourceMetadata.cs @@ -0,0 +1,145 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents the resource metadata for OAuth authorization as defined in RFC 9396. +/// Defined by RFC 9728. +/// +public sealed class ProtectedResourceMetadata +{ + /// + /// The resource URI. + /// + /// + /// REQUIRED. The protected resource's resource identifier. + /// + [JsonPropertyName("resource")] + public required Uri Resource { get; set; } + + /// + /// The list of authorization server URIs. + /// + /// + /// OPTIONAL. JSON array containing a list of OAuth authorization server issuer identifiers + /// for authorization servers that can be used with this protected resource. + /// + [JsonPropertyName("authorization_servers")] + public List AuthorizationServers { get; set; } = []; + + /// + /// The supported bearer token methods. + /// + /// + /// OPTIONAL. JSON array containing a list of the supported methods of sending an OAuth 2.0 bearer token + /// to the protected resource. Defined values are ["header", "body", "query"]. + /// + [JsonPropertyName("bearer_methods_supported")] + public List BearerMethodsSupported { get; set; } = ["header"]; + + /// + /// The supported scopes. + /// + /// + /// RECOMMENDED. JSON array containing a list of scope values that are used in authorization + /// requests to request access to this protected resource. + /// + [JsonPropertyName("scopes_supported")] + public List ScopesSupported { get; set; } = []; + + /// + /// URL of the protected resource's JSON Web Key (JWK) Set document. + /// + /// + /// OPTIONAL. This contains public keys belonging to the protected resource, such as signing key(s) + /// that the resource server uses to sign resource responses. This URL MUST use the https scheme. + /// + [JsonPropertyName("jwks_uri")] + public Uri? JwksUri { get; set; } + + /// + /// List of the JWS signing algorithms supported by the protected resource for signing resource responses. + /// + /// + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms (alg values) supported by the protected resource + /// for signing resource responses. No default algorithms are implied if this entry is omitted. The value none MUST NOT be used. + /// + [JsonPropertyName("resource_signing_alg_values_supported")] + public List? ResourceSigningAlgValuesSupported { get; set; } + + /// + /// Human-readable name of the protected resource intended for display to the end user. + /// + /// + /// RECOMMENDED. It is recommended that protected resource metadata include this field. + /// The value of this field MAY be internationalized. + /// + [JsonPropertyName("resource_name")] + public string? ResourceName { get; set; } + + /// + /// The URI to the resource documentation. + /// + /// + /// OPTIONAL. URL of a page containing human-readable information that developers might want or need to know + /// when using the protected resource. + /// + [JsonPropertyName("resource_documentation")] + public Uri? ResourceDocumentation { get; set; } + + /// + /// URL of a page containing human-readable information about the protected resource's requirements. + /// + /// + /// OPTIONAL. Information about how the client can use the data provided by the protected resource. + /// + [JsonPropertyName("resource_policy_uri")] + public Uri? ResourcePolicyUri { get; set; } + + /// + /// URL of a page containing human-readable information about the protected resource's terms of service. + /// + /// + /// OPTIONAL. The value of this field MAY be internationalized. + /// + [JsonPropertyName("resource_tos_uri")] + public Uri? ResourceTosUri { get; set; } + + /// + /// Boolean value indicating protected resource support for mutual-TLS client certificate-bound access tokens. + /// + /// + /// OPTIONAL. If omitted, the default value is false. + /// + [JsonPropertyName("tls_client_certificate_bound_access_tokens")] + public bool? TlsClientCertificateBoundAccessTokens { get; set; } + + /// + /// List of the authorization details type values supported by the resource server. + /// + /// + /// OPTIONAL. JSON array containing a list of the authorization details type values supported by the resource server + /// when the authorization_details request parameter is used. + /// + [JsonPropertyName("authorization_details_types_supported")] + public List? AuthorizationDetailsTypesSupported { get; set; } + + /// + /// List of the JWS algorithm values supported by the resource server for validating DPoP proof JWTs. + /// + /// + /// OPTIONAL. JSON array containing a list of the JWS alg values supported by the resource server + /// for validating Demonstrating Proof of Possession (DPoP) proof JWTs. + /// + [JsonPropertyName("dpop_signing_alg_values_supported")] + public List? DpopSigningAlgValuesSupported { get; set; } + + /// + /// Boolean value specifying whether the protected resource always requires the use of DPoP-bound access tokens. + /// + /// + /// OPTIONAL. If omitted, the default value is false. + /// + [JsonPropertyName("dpop_bound_access_tokens_required")] + public bool? DpopBoundAccessTokensRequired { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs new file mode 100644 index 000000000..dc55292b9 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -0,0 +1,57 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a token response from the OAuth server. +/// +internal sealed class TokenContainer +{ + /// + /// Gets or sets the access token. + /// + [JsonPropertyName("access_token")] + public string AccessToken { get; set; } = string.Empty; + + /// + /// Gets or sets the refresh token. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + [JsonPropertyName("expires_in")] + public int ExpiresIn { get; set; } + + /// + /// Gets or sets the extended expiration time in seconds. + /// + [JsonPropertyName("ext_expires_in")] + public int ExtExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + [JsonPropertyName("token_type")] + public string TokenType { get; set; } = string.Empty; + + /// + /// Gets or sets the scope of the access token. + /// + [JsonPropertyName("scope")] + public string Scope { get; set; } = string.Empty; + + /// + /// Gets or sets the timestamp when the token was obtained. + /// + [JsonIgnore] + public DateTimeOffset ObtainedAt { get; set; } + + /// + /// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn. + /// + [JsonIgnore] + public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn); +} diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 39ae7e81d..06f2e0bfb 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -13,13 +13,13 @@ namespace ModelContextProtocol.Client; internal sealed partial class AutoDetectingClientSessionTransport : ITransport { private readonly SseClientTransportOptions _options; - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly ILoggerFactory? _loggerFactory; private readonly ILogger _logger; private readonly string _name; private readonly Channel _messageChannel; - public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) + public AutoDetectingClientSessionTransport(string endpointName, SseClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); diff --git a/src/ModelContextProtocol.Core/Client/McpClientTool.cs b/src/ModelContextProtocol.Core/Client/McpClientTool.cs index a3cf7a46b..1810e9c56 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientTool.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientTool.cs @@ -77,6 +77,9 @@ internal McpClientTool( /// public override JsonElement JsonSchema => ProtocolTool.InputSchema; + /// + public override JsonElement? ReturnJsonSchema => ProtocolTool.OutputSchema; + /// public override JsonSerializerOptions JsonSerializerOptions { get; } diff --git a/src/ModelContextProtocol.Core/Client/McpHttpClient.cs b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs new file mode 100644 index 000000000..77ca78fb4 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs @@ -0,0 +1,42 @@ +using ModelContextProtocol.Protocol; +using System.Diagnostics; + +#if NET +using System.Net.Http.Json; +#else +using System.Text; +using System.Text.Json; +#endif + +namespace ModelContextProtocol.Client; + +internal class McpHttpClient(HttpClient httpClient) +{ + internal virtual async Task SendAsync(HttpRequestMessage request, JsonRpcMessage? message, CancellationToken cancellationToken) + { + Debug.Assert(request.Content is null, "The request body should only be supplied as a JsonRpcMessage"); + Debug.Assert(message is null || request.Method == HttpMethod.Post, "All messages should be sent in POST requests."); + + using var content = CreatePostBodyContent(message); + request.Content = content; + return await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + } + + private HttpContent? CreatePostBodyContent(JsonRpcMessage? message) + { + if (message is null) + { + return null; + } + +#if NET + return JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); +#else + return new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json" + ); +#endif + } +} diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 93559b7db..aba7bbcfb 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Net.Http.Headers; using System.Net.ServerSentEvents; -using System.Text; using System.Text.Json; using System.Threading.Channels; @@ -15,7 +14,7 @@ namespace ModelContextProtocol.Client; /// internal sealed partial class SseClientSessionTransport : TransportBase { - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly SseClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; @@ -31,7 +30,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase public SseClientSessionTransport( string endpointName, SseClientTransportOptions transportOptions, - HttpClient httpClient, + McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) : base(endpointName, messageChannel, loggerFactory) @@ -74,12 +73,6 @@ public override async Task SendMessageAsync( if (_messageEndpoint == null) throw new InvalidOperationException("Transport not connected"); - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" - ); - string messageId = "(no id)"; if (message is JsonRpcMessageWithId messageWithId) @@ -87,12 +80,9 @@ public override async Task SendMessageAsync( messageId = messageWithId.Id.ToString(); } - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) - { - Content = content, - }; + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint); StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); - var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -154,11 +144,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); - using var response = await _httpClient.SendAsync( - request, - HttpCompletionOption.ResponseHeadersRead, - cancellationToken - ).ConfigureAwait(false); + using var response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientTransport.cs index 3fba349b5..b31c3479b 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientTransport.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.Logging; +using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; @@ -15,9 +16,10 @@ namespace ModelContextProtocol.Client; public sealed class SseClientTransport : IClientTransport, IAsyncDisposable { private readonly SseClientTransportOptions _options; - private readonly HttpClient _httpClient; + private readonly McpHttpClient _mcpHttpClient; private readonly ILoggerFactory? _loggerFactory; - private readonly bool _ownsHttpClient; + + private readonly HttpClient? _ownedHttpClient; /// /// Initializes a new instance of the class. @@ -45,10 +47,23 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient Throw.IfNull(httpClient); _options = transportOptions; - _httpClient = httpClient; _loggerFactory = loggerFactory; - _ownsHttpClient = ownsHttpClient; Name = transportOptions.Name ?? transportOptions.Endpoint.ToString(); + + if (transportOptions.OAuth is { } clientOAuthOptions) + { + var oAuthProvider = new ClientOAuthProvider(_options.Endpoint, clientOAuthOptions, httpClient, loggerFactory); + _mcpHttpClient = new AuthenticatingMcpHttpClient(httpClient, oAuthProvider); + } + else + { + _mcpHttpClient = new(httpClient); + } + + if (ownsHttpClient) + { + _ownedHttpClient = httpClient; + } } /// @@ -59,8 +74,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = { return _options.TransportMode switch { - HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(_options, _httpClient, _loggerFactory, Name), - HttpTransportMode.StreamableHttp => new StreamableHttpClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory), + HttpTransportMode.AutoDetect => new AutoDetectingClientSessionTransport(Name, _options, _mcpHttpClient, _loggerFactory), + HttpTransportMode.StreamableHttp => new StreamableHttpClientSessionTransport(Name, _options, _mcpHttpClient, messageChannel: null, _loggerFactory), HttpTransportMode.Sse => await ConnectSseTransportAsync(cancellationToken).ConfigureAwait(false), _ => throw new InvalidOperationException($"Unsupported transport mode: {_options.TransportMode}"), }; @@ -68,7 +83,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = private async Task ConnectSseTransportAsync(CancellationToken cancellationToken) { - var sessionTransport = new SseClientSessionTransport(Name, _options, _httpClient, messageChannel: null, _loggerFactory); + var sessionTransport = new SseClientSessionTransport(Name, _options, _mcpHttpClient, messageChannel: null, _loggerFactory); try { @@ -85,11 +100,7 @@ private async Task ConnectSseTransportAsync(CancellationToken cancel /// public ValueTask DisposeAsync() { - if (_ownsHttpClient) - { - _httpClient.Dispose(); - } - + _ownedHttpClient?.Dispose(); return default; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs index 9b4af6db5..4097844cf 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs @@ -1,3 +1,5 @@ +using ModelContextProtocol.Authentication; + namespace ModelContextProtocol.Client; /// @@ -46,7 +48,7 @@ public required Uri Endpoint public HttpTransportMode TransportMode { get; set; } = HttpTransportMode.AutoDetect; /// - /// Gets a transport identifier used for logging purposes. + /// Gets or sets a transport identifier used for logging purposes. /// public string? Name { get; set; } @@ -70,4 +72,9 @@ public required Uri Endpoint /// Use this property to specify custom HTTP headers that should be sent with each request to the server. /// public IDictionary? AdditionalHeaders { get; set; } + + /// + /// Gets sor sets the authorization provider to use for authentication. + /// + public ClientOAuthOptions? OAuth { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 14df5c353..190bec0b2 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -6,12 +6,6 @@ using ModelContextProtocol.Protocol; using System.Threading.Channels; -#if NET -using System.Net.Http.Json; -#else -using System.Text; -#endif - namespace ModelContextProtocol.Client; /// @@ -22,7 +16,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json"); private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); - private readonly HttpClient _httpClient; + private readonly McpHttpClient _httpClient; private readonly SseClientTransportOptions _options; private readonly CancellationTokenSource _connectionCts; private readonly ILogger _logger; @@ -36,7 +30,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa public StreamableHttpClientSessionTransport( string endpointName, SseClientTransportOptions transportOptions, - HttpClient httpClient, + McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) : base(endpointName, messageChannel, loggerFactory) @@ -69,19 +63,8 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token); cancellationToken = sendCts.Token; -#if NET - using var content = JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); -#else - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json; charset=utf-8" - ); -#endif - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _options.Endpoint) { - Content = content, Headers = { Accept = { s_applicationJsonMediaType, s_textEventStreamMediaType }, @@ -90,7 +73,7 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); - var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false); // We'll let the caller decide whether to throw or fall back given an unsuccessful response. if (!response.IsSuccessStatusCode) @@ -192,7 +175,7 @@ private async Task ReceiveUnsolicitedMessagesAsync() request.Headers.Accept.Add(s_textEventStreamMediaType); CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); - using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false); + using var response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); if (!response.IsSuccessStatusCode) { @@ -261,7 +244,7 @@ private async Task SendDeleteRequest() try { // Do not validate we get a successful status code, because server support for the DELETE request is optional - (await _httpClient.SendAsync(deleteRequest, CancellationToken.None).ConfigureAwait(false)).Dispose(); + (await _httpClient.SendAsync(deleteRequest, message: null, CancellationToken.None).ConfigureAwait(false)).Dispose(); } catch (Exception ex) { @@ -310,4 +293,4 @@ internal static void CopyAdditionalHeaders( } } } -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index 696e0ec05..21e2468d9 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.AI; +using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; using System.Diagnostics.CodeAnalysis; using System.Text.Json; @@ -154,6 +155,12 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(IReadOnlyDictionary))] [JsonSerializable(typeof(ProgressToken))] + [JsonSerializable(typeof(ProtectedResourceMetadata))] + [JsonSerializable(typeof(AuthorizationServerMetadata))] + [JsonSerializable(typeof(TokenContainer))] + [JsonSerializable(typeof(DynamicClientRegistrationRequest))] + [JsonSerializable(typeof(DynamicClientRegistrationResponse))] + // Primitive types for use in consuming AIFunctions [JsonSerializable(typeof(string))] [JsonSerializable(typeof(byte))] diff --git a/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs b/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs index 7438522cc..5d4750aa2 100644 --- a/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs +++ b/src/ModelContextProtocol.Core/Protocol/CallToolResult.cs @@ -44,5 +44,5 @@ public sealed class CallToolResult : Result /// and potentially self-correct in subsequent requests. /// [JsonPropertyName("isError")] - public bool IsError { get; set; } + public bool? IsError { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs index db86c7b66..516ea2446 100644 --- a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs +++ b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs @@ -149,7 +149,7 @@ public class Converter : JsonConverter Meta = meta, }, - "image" => new ImageContentBlock() + "image" => new ImageContentBlock { Data = data ?? throw new JsonException("Image data must be provided for 'image' type."), MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'image' type."), @@ -157,7 +157,7 @@ public class Converter : JsonConverter Meta = meta, }, - "audio" => new AudioContentBlock() + "audio" => new AudioContentBlock { Data = data ?? throw new JsonException("Audio data must be provided for 'audio' type."), MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'audio' type."), @@ -165,14 +165,14 @@ public class Converter : JsonConverter Meta = meta, }, - "resource" => new EmbeddedResourceBlock() + "resource" => new EmbeddedResourceBlock { Resource = resource ?? throw new JsonException("Resource contents must be provided for 'resource' type."), Annotations = annotations, Meta = meta, }, - "resource_link" => new ResourceLinkBlock() + "resource_link" => new ResourceLinkBlock { Uri = uri ?? throw new JsonException("URI must be provided for 'resource_link' type."), Name = name ?? throw new JsonException("Name must be provided for 'resource_link' type."), diff --git a/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs index 05d8a49ae..3a9926e22 100644 --- a/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs @@ -1,4 +1,7 @@ +using System.ComponentModel; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Text.Json; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; @@ -54,39 +57,273 @@ public IDictionary Properties public IList? Required { get; set; } } - /// /// Represents restricted subset of JSON Schema: /// , , , or . /// - [JsonDerivedType(typeof(BooleanSchema))] - [JsonDerivedType(typeof(EnumSchema))] - [JsonDerivedType(typeof(NumberSchema))] - [JsonDerivedType(typeof(StringSchema))] + [JsonConverter(typeof(Converter))] // TODO: This converter exists due to the lack of downlevel support for AllowOutOfOrderMetadataProperties. public abstract class PrimitiveSchemaDefinition { /// Prevent external derivations. protected private PrimitiveSchemaDefinition() { } - } - /// Represents a schema for a string type. - public sealed class StringSchema : PrimitiveSchemaDefinition - { /// Gets the type of the schema. - /// This is always "string". [JsonPropertyName("type")] - public string Type => "string"; + public abstract string Type { get; set; } - /// Gets or sets a title for the string. + /// Gets or sets a title for the schema. [JsonPropertyName("title")] public string? Title { get; set; } - /// Gets or sets a description for the string. + /// Gets or sets a description for the schema. [JsonPropertyName("description")] public string? Description { get; set; } + /// + /// Provides a for . + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public class Converter : JsonConverter + { + /// + public override PrimitiveSchemaDefinition? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException(); + } + + string? type = null; + string? title = null; + string? description = null; + int? minLength = null; + int? maxLength = null; + string? format = null; + double? minimum = null; + double? maximum = null; + bool? defaultBool = null; + IList? enumValues = null; + IList? enumNames = null; + + while (reader.Read() && reader.TokenType != JsonTokenType.EndObject) + { + if (reader.TokenType != JsonTokenType.PropertyName) + { + continue; + } + + string? propertyName = reader.GetString(); + bool success = reader.Read(); + Debug.Assert(success, "STJ must have buffered the entire object for us."); + + switch (propertyName) + { + case "type": + type = reader.GetString(); + break; + + case "title": + title = reader.GetString(); + break; + + case "description": + description = reader.GetString(); + break; + + case "minLength": + minLength = reader.GetInt32(); + break; + + case "maxLength": + maxLength = reader.GetInt32(); + break; + + case "format": + format = reader.GetString(); + break; + + case "minimum": + minimum = reader.GetDouble(); + break; + + case "maximum": + maximum = reader.GetDouble(); + break; + + case "default": + defaultBool = reader.GetBoolean(); + break; + + case "enum": + enumValues = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.JsonContext.Default.IListString); + break; + + case "enumNames": + enumNames = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.JsonContext.Default.IListString); + break; + + default: + break; + } + } + + if (type is null) + { + throw new JsonException("The 'type' property is required."); + } + + PrimitiveSchemaDefinition? psd = null; + switch (type) + { + case "string": + if (enumValues is not null) + { + psd = new EnumSchema + { + Enum = enumValues, + EnumNames = enumNames + }; + } + else + { + psd = new StringSchema + { + MinLength = minLength, + MaxLength = maxLength, + Format = format, + }; + } + break; + + case "integer": + case "number": + psd = new NumberSchema + { + Minimum = minimum, + Maximum = maximum, + }; + break; + + case "boolean": + psd = new BooleanSchema + { + Default = defaultBool, + }; + break; + } + + if (psd is not null) + { + psd.Type = type; + psd.Title = title; + psd.Description = description; + } + + return psd; + } + + /// + public override void Write(Utf8JsonWriter writer, PrimitiveSchemaDefinition value, JsonSerializerOptions options) + { + if (value is null) + { + writer.WriteNullValue(); + return; + } + + writer.WriteStartObject(); + + writer.WriteString("type", value.Type); + if (value.Title is not null) + { + writer.WriteString("title", value.Title); + } + if (value.Description is not null) + { + writer.WriteString("description", value.Description); + } + + switch (value) + { + case StringSchema stringSchema: + if (stringSchema.MinLength.HasValue) + { + writer.WriteNumber("minLength", stringSchema.MinLength.Value); + } + if (stringSchema.MaxLength.HasValue) + { + writer.WriteNumber("maxLength", stringSchema.MaxLength.Value); + } + if (stringSchema.Format is not null) + { + writer.WriteString("format", stringSchema.Format); + } + break; + + case NumberSchema numberSchema: + if (numberSchema.Minimum.HasValue) + { + writer.WriteNumber("minimum", numberSchema.Minimum.Value); + } + if (numberSchema.Maximum.HasValue) + { + writer.WriteNumber("maximum", numberSchema.Maximum.Value); + } + break; + + case BooleanSchema booleanSchema: + if (booleanSchema.Default.HasValue) + { + writer.WriteBoolean("default", booleanSchema.Default.Value); + } + break; + + case EnumSchema enumSchema: + if (enumSchema.Enum is not null) + { + writer.WritePropertyName("enum"); + JsonSerializer.Serialize(writer, enumSchema.Enum, McpJsonUtilities.JsonContext.Default.IListString); + } + if (enumSchema.EnumNames is not null) + { + writer.WritePropertyName("enumNames"); + JsonSerializer.Serialize(writer, enumSchema.EnumNames, McpJsonUtilities.JsonContext.Default.IListString); + } + break; + + default: + throw new JsonException($"Unexpected schema type: {value.GetType().Name}"); + } + + writer.WriteEndObject(); + } + } + } + + /// Represents a schema for a string type. + public sealed class StringSchema : PrimitiveSchemaDefinition + { + /// + [JsonPropertyName("type")] + public override string Type + { + get => "string"; + set + { + if (value is not "string") + { + throw new ArgumentException("Type must be 'string'.", nameof(value)); + } + } + } + /// Gets or sets the minimum length for the string. [JsonPropertyName("minLength")] public int? MinLength @@ -139,11 +376,9 @@ public string? Format /// Represents a schema for a number or integer type. public sealed class NumberSchema : PrimitiveSchemaDefinition { - /// Gets the type of the schema. - /// This should be "number" or "integer". - [JsonPropertyName("type")] + /// [field: MaybeNull] - public string Type + public override string Type { get => field ??= "number"; set @@ -157,14 +392,6 @@ public string Type } } - /// Gets or sets a title for the number input. - [JsonPropertyName("title")] - public string? Title { get; set; } - - /// Gets or sets a description for the number input. - [JsonPropertyName("description")] - public string? Description { get; set; } - /// Gets or sets the minimum allowed value. [JsonPropertyName("minimum")] public double? Minimum { get; set; } @@ -177,18 +404,19 @@ public string Type /// Represents a schema for a Boolean type. public sealed class BooleanSchema : PrimitiveSchemaDefinition { - /// Gets the type of the schema. - /// This is always "boolean". + /// [JsonPropertyName("type")] - public string Type => "boolean"; - - /// Gets or sets a title for the Boolean. - [JsonPropertyName("title")] - public string? Title { get; set; } - - /// Gets or sets a description for the Boolean. - [JsonPropertyName("description")] - public string? Description { get; set; } + public override string Type + { + get => "boolean"; + set + { + if (value is not "boolean") + { + throw new ArgumentException("Type must be 'boolean'.", nameof(value)); + } + } + } /// Gets or sets the default value for the Boolean. [JsonPropertyName("default")] @@ -198,18 +426,19 @@ public sealed class BooleanSchema : PrimitiveSchemaDefinition /// Represents a schema for an enum type. public sealed class EnumSchema : PrimitiveSchemaDefinition { - /// Gets the type of the schema. - /// This is always "string". + /// [JsonPropertyName("type")] - public string Type => "string"; - - /// Gets or sets a title for the enum. - [JsonPropertyName("title")] - public string? Title { get; set; } - - /// Gets or sets a description for the enum. - [JsonPropertyName("description")] - public string? Description { get; set; } + public override string Type + { + get => "string"; + set + { + if (value is not "string") + { + throw new ArgumentException("Type must be 'string'.", nameof(value)); + } + } + } /// Gets or sets the list of allowed string values for the enum. [JsonPropertyName("enum")] diff --git a/src/ModelContextProtocol.Core/Protocol/ProgressNotificationParams.cs b/src/ModelContextProtocol.Core/Protocol/ProgressNotificationParams.cs index 1fd927061..0b661d181 100644 --- a/src/ModelContextProtocol.Core/Protocol/ProgressNotificationParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/ProgressNotificationParams.cs @@ -98,7 +98,7 @@ public sealed class Converter : JsonConverter return new ProgressNotificationParams { ProgressToken = progressToken.GetValueOrDefault(), - Progress = new ProgressNotificationValue() + Progress = new ProgressNotificationValue { Progress = progress.GetValueOrDefault(), Total = total, diff --git a/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs b/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs index 1332a6aa4..f6486488b 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs @@ -103,5 +103,5 @@ public sealed class ResourcesCapability /// /// [JsonIgnore] - public McpServerPrimitiveCollection? ResourceCollection { get; set; } + public McpServerResourceCollection? ResourceCollection { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs index 8d446c58c..d651d7ee3 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; using System.ComponentModel; +using System.Diagnostics; using System.Reflection; using System.Text.Json; @@ -57,8 +58,8 @@ internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt return Create( AIFunctionFactory.Create(method, args => { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); + Debug.Assert(args.Services is RequestServiceProvider, $"The service provider should be a {nameof(RequestServiceProvider)} for this method to work correctly."); + return createTargetFunc(((RequestServiceProvider)args.Services!).Request); }, CreateAIFunctionFactoryOptions(method, options)), options); } @@ -67,61 +68,22 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( MethodInfo method, McpServerPromptCreateOptions? options) => new() { - Name = options?.Name ?? method.GetCustomAttribute()?.Name, + Name = options?.Name ?? method.GetCustomAttribute()?.Name ?? AIFunctionMcpServerTool.DeriveName(method), Description = options?.Description, MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, JsonSchemaCreateOptions = options?.SchemaCreateOptions, ConfigureParameterBinding = pi => { - if (pi.ParameterType == typeof(RequestContext)) + if (RequestServiceProvider.IsAugmentedWith(pi.ParameterType) || + (options?.Services?.GetService() is { } ispis && + ispis.IsService(pi.ParameterType))) { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } - - if (pi.ParameterType == typeof(IMcpServer)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } - - if (pi.ParameterType == typeof(IProgress)) - { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. return new() { ExcludeFromSchema = true, BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.ProgressToken is { } progressToken) - { - return new TokenProgress(server, progressToken); - } - - return NullProgress.Instance; - }, - }; - } - - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + args.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -133,24 +95,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; } return default; - - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && - orc is RequestContext requestContext) - { - return requestContext; - } - - return null; - } }, }; @@ -226,14 +177,10 @@ public override async ValueTask GetAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - AIFunctionArguments arguments = new() - { - Services = request.Services, - Context = new Dictionary() { [typeof(RequestContext)] = request } - }; + request.Services = new RequestServiceProvider(request, request.Services); + AIFunctionArguments arguments = new() { Services = request.Services }; - var argDict = request.Params?.Arguments; - if (argDict is not null) + if (request.Params?.Arguments is { } argDict) { foreach (var kvp in argDict) { diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs index 487fed74c..a8b0d2486 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Protocol; using System.Collections.Concurrent; using System.ComponentModel; +using System.Diagnostics; using System.Globalization; using System.Reflection; using System.Text; @@ -64,8 +65,8 @@ internal sealed class AIFunctionMcpServerResource : McpServerResource return Create( AIFunctionFactory.Create(method, args => { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); + Debug.Assert(args.Services is RequestServiceProvider, $"The service provider should be a {nameof(RequestServiceProvider)} for this method to work correctly."); + return createTargetFunc(((RequestServiceProvider)args.Services!).Request); }, CreateAIFunctionFactoryOptions(method, options)), options); } @@ -74,61 +75,22 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( MethodInfo method, McpServerResourceCreateOptions? options) => new() { - Name = options?.Name ?? method.GetCustomAttribute()?.Name, + Name = options?.Name ?? method.GetCustomAttribute()?.Name ?? AIFunctionMcpServerTool.DeriveName(method), Description = options?.Description, MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, JsonSchemaCreateOptions = options?.SchemaCreateOptions, ConfigureParameterBinding = pi => { - if (pi.ParameterType == typeof(RequestContext)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } - - if (pi.ParameterType == typeof(IMcpServer)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } - - if (pi.ParameterType == typeof(IProgress)) - { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.ProgressToken is { } progressToken) - { - return new TokenProgress(server, progressToken); - } - - return NullProgress.Instance; - }, - }; - } - - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) + if (RequestServiceProvider.IsAugmentedWith(pi.ParameterType) || + (options?.Services?.GetService() is { } ispis && + ispis.IsService(pi.ParameterType))) { return new() { ExcludeFromSchema = true, BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + args.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -140,7 +102,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -172,17 +134,6 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( } return default; - - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var rc) is true && - rc is RequestContext requestContext) - { - return requestContext; - } - - return null; - } }, }; @@ -264,7 +215,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Name = name, Title = options?.Title, Description = options?.Description, - MimeType = options?.MimeType, + MimeType = options?.MimeType ?? "application/octet-stream", }; return new AIFunctionMcpServerResource(function, resource); @@ -295,7 +246,7 @@ private static string DeriveUriTemplate(string name, AIFunction function) { StringBuilder template = new(); - template.Append("resource://").Append(Uri.EscapeDataString(name)); + template.Append("resource://mcp/").Append(Uri.EscapeDataString(name)); if (function.JsonSchema.TryGetProperty("properties", out JsonElement properties)) { @@ -359,17 +310,14 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour return null; } } - else if (request.Params.Uri != ProtocolResource!.Uri) + else if (!UriTemplate.UriTemplateComparer.Instance.Equals(request.Params.Uri, ProtocolResource!.Uri)) { return null; } // Build up the arguments for the AIFunction call, including all of the name/value pairs from the URI. - AIFunctionArguments arguments = new() - { - Services = request.Services, - Context = new Dictionary() { [typeof(RequestContext)] = request } - }; + request.Services = new RequestServiceProvider(request, request.Services); + AIFunctionArguments arguments = new() { Services = request.Services }; // For templates, populate the arguments from the URI template. if (match is not null) @@ -421,14 +369,14 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour Contents = aiContents.Select( ac => ac switch { - TextContent tc => new TextResourceContents() + TextContent tc => new TextResourceContents { Uri = request.Params!.Uri, MimeType = ProtocolResourceTemplate.MimeType, Text = tc.Text }, - DataContent dc => new BlobResourceContents() + DataContent dc => new BlobResourceContents { Uri = request.Params!.Uri, MimeType = dc.MediaType, @@ -441,7 +389,7 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour IEnumerable strings => new() { - Contents = strings.Select(text => new TextResourceContents() + Contents = strings.Select(text => new TextResourceContents { Uri = request.Params!.Uri, MimeType = ProtocolResourceTemplate.MimeType, diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index 39dd19e3e..afd3912b6 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -4,10 +4,11 @@ using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics; using System.Reflection; using System.Text.Json; using System.Text.Json.Nodes; +using System.Text.RegularExpressions; namespace ModelContextProtocol.Server; @@ -64,79 +65,32 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool return Create( AIFunctionFactory.Create(method, args => { - var request = (RequestContext)args.Context![typeof(RequestContext)]!; - return createTargetFunc(request); + Debug.Assert(args.Services is RequestServiceProvider, $"The service provider should be a {nameof(RequestServiceProvider)} for this method to work correctly."); + return createTargetFunc(((RequestServiceProvider)args.Services!).Request); }, CreateAIFunctionFactoryOptions(method, options)), options); } - // TODO: Fix the need for this suppression. - [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2111:ReflectionToDynamicallyAccessedMembers", - Justification = "AIFunctionFactory ensures that the Type passed to AIFunctionFactoryOptions.CreateInstance has public constructors preserved")] - internal static Func GetCreateInstanceFunc() => - static ([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] type, args) => args.Services is { } services ? - ActivatorUtilities.CreateInstance(services, type) : - Activator.CreateInstance(type)!; - private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( MethodInfo method, McpServerToolCreateOptions? options) => new() { - Name = options?.Name ?? method.GetCustomAttribute()?.Name, + Name = options?.Name ?? method.GetCustomAttribute()?.Name ?? DeriveName(method), Description = options?.Description, MarshalResult = static (result, _, cancellationToken) => new ValueTask(result), SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions, JsonSchemaCreateOptions = options?.SchemaCreateOptions, ConfigureParameterBinding = pi => { - if (pi.ParameterType == typeof(RequestContext)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args), - }; - } - - if (pi.ParameterType == typeof(IMcpServer)) + if (RequestServiceProvider.IsAugmentedWith(pi.ParameterType) || + (options?.Services?.GetService() is { } ispis && + ispis.IsService(pi.ParameterType))) { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => GetRequestContext(args)?.Server, - }; - } - - if (pi.ParameterType == typeof(IProgress)) - { - // Bind IProgress to the progress token in the request, - // if there is one. If we can't get one, return a nop progress. return new() { ExcludeFromSchema = true, BindParameter = (pi, args) => - { - var requestContent = GetRequestContext(args); - if (requestContent?.Server is { } server && - requestContent?.Params?.ProgressToken is { } progressToken) - { - return new TokenProgress(server, progressToken); - } - - return NullProgress.Instance; - }, - }; - } - - if (options?.Services is { } services && - services.GetService() is { } ispis && - ispis.IsService(pi.ParameterType)) - { - return new() - { - ExcludeFromSchema = true, - BindParameter = (pi, args) => - GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? + args.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; @@ -148,24 +102,13 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( { ExcludeFromSchema = true, BindParameter = (pi, args) => - (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (args?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : throw new ArgumentException("No service of the requested type was found.")), }; } return default; - - static RequestContext? GetRequestContext(AIFunctionArguments args) - { - if (args.Context?.TryGetValue(typeof(RequestContext), out var orc) is true && - orc is RequestContext requestContext) - { - return requestContext; - } - - return null; - } }, }; @@ -268,14 +211,10 @@ public override async ValueTask InvokeAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - AIFunctionArguments arguments = new() - { - Services = request.Services, - Context = new Dictionary() { [typeof(RequestContext)] = request } - }; + request.Services = new RequestServiceProvider(request, request.Services); + AIFunctionArguments arguments = new() { Services = request.Services }; - var argDict = request.Params?.Arguments; - if (argDict is not null) + if (request.Params?.Arguments is { } argDict) { foreach (var kvp in argDict) { @@ -355,6 +294,63 @@ public override async ValueTask InvokeAsync( }; } + /// Creates a name to use based on the supplied method and naming policy. + internal static string DeriveName(MethodInfo method, JsonNamingPolicy? policy = null) + { + string name = method.Name; + + // Remove any "Async" suffix if the method is an async method and if the method name isn't just "Async". + const string AsyncSuffix = "Async"; + if (IsAsyncMethod(method) && + name.EndsWith(AsyncSuffix, StringComparison.Ordinal) && + name.Length > AsyncSuffix.Length) + { + name = name.Substring(0, name.Length - AsyncSuffix.Length); + } + + // Replace anything other than ASCII letters or digits with underscores, trim off any leading or trailing underscores. + name = NonAsciiLetterDigitsRegex().Replace(name, "_").Trim('_'); + + // If after all our transformations the name is empty, just use the original method name. + if (name.Length == 0) + { + name = method.Name; + } + + // Case the name based on the provided naming policy. + return (policy ?? JsonNamingPolicy.SnakeCaseLower).ConvertName(name) ?? name; + + static bool IsAsyncMethod(MethodInfo method) + { + Type t = method.ReturnType; + + if (t == typeof(Task) || t == typeof(ValueTask)) + { + return true; + } + + if (t.IsGenericType) + { + t = t.GetGenericTypeDefinition(); + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + { + return true; + } + } + + return false; + } + } + + /// Regex that flags runs of characters other than ASCII digits or letters. +#if NET + [GeneratedRegex("[^0-9A-Za-z]+")] + private static partial Regex NonAsciiLetterDigitsRegex(); +#else + private static Regex NonAsciiLetterDigitsRegex() => _nonAsciiLetterDigits; + private static readonly Regex _nonAsciiLetterDigits = new("[^0-9A-Za-z]+", RegexOptions.Compiled); +#endif + private static JsonElement? CreateOutputSchema(AIFunction function, McpServerToolCreateOptions? toolCreateOptions, out bool structuredOutputRequiresWrapping) { structuredOutputRequiresWrapping = false; diff --git a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs b/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs new file mode 100644 index 000000000..3372072fe --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs @@ -0,0 +1,58 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// Augments a service provider with additional request-related services. +internal sealed class RequestServiceProvider( + RequestContext request, IServiceProvider? innerServices) : + IServiceProvider, IKeyedServiceProvider, + IServiceProviderIsService, IServiceProviderIsKeyedService, + IDisposable, IAsyncDisposable + where TRequestParams : RequestParams +{ + /// Gets the request associated with this instance. + public RequestContext Request => request; + + /// Gets whether the specified type is in the list of additional types this service provider wraps around the one in a provided request's services. + public static bool IsAugmentedWith(Type serviceType) => + serviceType == typeof(RequestContext) || + serviceType == typeof(IMcpServer) || + serviceType == typeof(IProgress); + + /// + public object? GetService(Type serviceType) => + serviceType == typeof(RequestContext) ? request : + serviceType == typeof(IMcpServer) ? request.Server : + serviceType == typeof(IProgress) ? + (request.Params?.ProgressToken is { } progressToken ? new TokenProgress(request.Server, progressToken) : NullProgress.Instance) : + innerServices?.GetService(serviceType); + + /// + public bool IsService(Type serviceType) => + IsAugmentedWith(serviceType) || + (innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; + + /// + public bool IsKeyedService(Type serviceType, object? serviceKey) => + (serviceKey is null && IsService(serviceType)) || + (innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; + + /// + public object? GetKeyedService(Type serviceType, object? serviceKey) => + serviceKey is null ? GetService(serviceType) : + (innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); + + /// + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => + GetKeyedService(serviceType, serviceKey) ?? + throw new InvalidOperationException($"No service of type '{serviceType}' with key '{serviceKey}' is registered."); + + /// + public void Dispose() => + (innerServices as IDisposable)?.Dispose(); + + /// + public ValueTask DisposeAsync() => + innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 829e0a865..6c5858f91 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -228,7 +228,7 @@ private void ConfigureResources(McpServerOptions options) var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler ?? (static async (_, __) => new EmptyResult()); var resources = resourcesCapability.ResourceCollection; var listChanged = resourcesCapability.ListChanged; - var subcribe = resourcesCapability.Subscribe; + var subscribe = resourcesCapability.Subscribe; // Handle resources provided via DI. if (resources is { IsEmpty: false }) @@ -309,7 +309,7 @@ await originalListResourceTemplatesHandler(request, cancellationToken).Configure listChanged = true; // TODO: Implement subscribe/unsubscribe logic for resource and resource template collections. - // subcribe = true; + // subscribe = true; } ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; @@ -319,7 +319,7 @@ await originalListResourceTemplatesHandler(request, cancellationToken).Configure ServerCapabilities.Resources.SubscribeToResourcesHandler = subscribeHandler; ServerCapabilities.Resources.UnsubscribeFromResourcesHandler = unsubscribeHandler; ServerCapabilities.Resources.ListChanged = listChanged; - ServerCapabilities.Resources.Subscribe = subcribe; + ServerCapabilities.Resources.Subscribe = subscribe; SetHandler( RequestMethods.ResourcesList, diff --git a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs index 1b435c6a7..d00c41a6b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs @@ -103,12 +103,12 @@ public static async Task SampleAsync( { Role = role, Content = dataContent.HasTopLevelMediaType("image") ? - new ImageContentBlock() + new ImageContentBlock { MimeType = dataContent.MediaType, Data = dataContent.Base64Data.ToString(), } : - new AudioContentBlock() + new AudioContentBlock { MimeType = dataContent.MediaType, Data = dataContent.Base64Data.ToString(), @@ -344,7 +344,7 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except void Log(LogLevel logLevel, string message) { - _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams() + _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams { Level = McpServer.ToLoggingLevel(logLevel), Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs b/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs index 7bfe0232f..f891858eb 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs @@ -10,13 +10,14 @@ public class McpServerPrimitiveCollection : ICollection, IReadOnlyCollecti where T : IMcpServerPrimitive { /// Concurrent dictionary of primitives, indexed by their names. - private readonly ConcurrentDictionary _primitives = []; + private readonly ConcurrentDictionary _primitives; /// /// Initializes a new instance of the class. /// - public McpServerPrimitiveCollection() + public McpServerPrimitiveCollection(IEqualityComparer? keyComparer = null) { + _primitives = new(keyComparer); } /// Occurs when the collection is changed. diff --git a/src/ModelContextProtocol.Core/Server/McpServerResourceCollection.cs b/src/ModelContextProtocol.Core/Server/McpServerResourceCollection.cs new file mode 100644 index 000000000..fb5f6b4e2 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServerResourceCollection.cs @@ -0,0 +1,5 @@ +namespace ModelContextProtocol.Server; + +/// Provides a thread-safe collection of instances, indexed by their URI templates. +public sealed class McpServerResourceCollection() + : McpServerPrimitiveCollection(UriTemplate.UriTemplateComparer.Instance); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs index 97b0a38b0..d4ea9eb75 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs @@ -245,8 +245,13 @@ public bool ReadOnly /// Gets or sets whether the tool should report an output schema for structured content. /// /// + /// /// When enabled, the tool will attempt to populate the /// and provide structured content in the property. + /// + /// + /// The default is . + /// /// public bool UseStructuredContent { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs index 63407d882..bdb4ecb8d 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs @@ -129,8 +129,13 @@ public sealed class McpServerToolCreateOptions /// Gets or sets whether the tool should report an output schema for structured content. /// /// + /// /// When enabled, the tool will attempt to populate the /// and provide structured content in the property. + /// + /// + /// The default is . + /// /// public bool UseStructuredContent { get; set; } diff --git a/src/ModelContextProtocol.Core/UriTemplate.cs b/src/ModelContextProtocol.Core/UriTemplate.cs index bc6b70c9f..a822b6f2a 100644 --- a/src/ModelContextProtocol.Core/UriTemplate.cs +++ b/src/ModelContextProtocol.Core/UriTemplate.cs @@ -2,6 +2,7 @@ using System.Buffers; #endif using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Runtime.CompilerServices; using System.Text; @@ -453,4 +454,53 @@ static void AppendHex(ref DefaultInterpolatedStringHandler builder, char c) } } } + + /// + /// Defines an equality comparer for Uri templates as follows: + /// 1. Non-templated Uris use regular System.Uri equality comparison (host name is case insensitive). + /// 2. Templated Uris use regular string equality. + /// + /// We do this because non-templated resources are looked up directly from the resource dictionary + /// and we need to make sure equality is implemented correctly. Templated Uris are resolved in a + /// fallback step using linear traversal of the resource dictionary, so their equality is only + /// there to distinguish between different templates. + /// + public sealed class UriTemplateComparer : IEqualityComparer + { + public static IEqualityComparer Instance { get; } = new UriTemplateComparer(); + + public bool Equals(string? uriTemplate1, string? uriTemplate2) + { + if (TryParseAsNonTemplatedUri(uriTemplate1, out Uri? uri1) && + TryParseAsNonTemplatedUri(uriTemplate2, out Uri? uri2)) + { + return uri1 == uri2; + } + + return string.Equals(uriTemplate1, uriTemplate2, StringComparison.Ordinal); + } + + public int GetHashCode([DisallowNull] string uriTemplate) + { + if (TryParseAsNonTemplatedUri(uriTemplate, out Uri? uri)) + { + return uri.GetHashCode(); + } + else + { + return StringComparer.Ordinal.GetHashCode(uriTemplate); + } + } + + private static bool TryParseAsNonTemplatedUri(string? uriTemplate, [NotNullWhen(true)] out Uri? uri) + { + if (uriTemplate is null || uriTemplate.Contains('{')) + { + uri = null; + return false; + } + + return Uri.TryCreate(uriTemplate, UriKind.Absolute, out uri); + } + } } \ No newline at end of file diff --git a/src/ModelContextProtocol/McpServerOptionsSetup.cs b/src/ModelContextProtocol/McpServerOptionsSetup.cs index effa41463..7fe4f61cb 100644 --- a/src/ModelContextProtocol/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/McpServerOptionsSetup.cs @@ -63,7 +63,7 @@ public void Configure(McpServerOptions options) // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants // change notifications, etc. - McpServerPrimitiveCollection resourceCollection = options.Capabilities?.Resources?.ResourceCollection ?? []; + McpServerResourceCollection resourceCollection = options.Capabilities?.Resources?.ResourceCollection ?? []; foreach (var resource in serverResources) { resourceCollection.TryAdd(resource); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs new file mode 100644 index 000000000..2252b1b7c --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs @@ -0,0 +1,407 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Authentication; +using ModelContextProtocol.Client; +using System.Net; +using System.Reflection; +using Xunit.Sdk; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class AuthTests : KestrelInMemoryTest, IAsyncDisposable +{ + private const string McpServerUrl = "http://localhost:5000/"; + private const string OAuthServerUrl = "https://localhost:7029/"; + + private readonly CancellationTokenSource _testCts = new(); + private readonly TestOAuthServer.Program _testOAuthServer; + private readonly Task _testOAuthRunTask; + + private Uri? _lastAuthorizationUri; + + public AuthTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + // Let the HandleAuthorizationUrlAsync take a look at the Location header + SocketsHttpHandler.AllowAutoRedirect = false; + // The dev cert may not be installed on the CI, but AddJwtBearer requires an HTTPS backchannel by default. + // The easiest workaround is to disable cert validation for testing purposes. + SocketsHttpHandler.SslOptions.RemoteCertificateValidationCallback = (_, _, _, _) => true; + + _testOAuthServer = new TestOAuthServer.Program(XunitLoggerProvider, KestrelInMemoryTransport); + _testOAuthRunTask = _testOAuthServer.RunServerAsync(cancellationToken: _testCts.Token); + + Builder.Services.AddAuthentication(options => + { + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + }) + .AddJwtBearer(options => + { + options.Backchannel = HttpClient; + options.Authority = OAuthServerUrl; + options.TokenValidationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidateAudience = true, + ValidateLifetime = true, + ValidateIssuerSigningKey = true, + ValidAudience = McpServerUrl, + ValidIssuer = OAuthServerUrl, + NameClaimType = "name", + RoleClaimType = "roles" + }; + }) + .AddMcp(options => + { + options.ResourceMetadata = new ProtectedResourceMetadata + { + Resource = new Uri(McpServerUrl), + AuthorizationServers = { new Uri(OAuthServerUrl) }, + ScopesSupported = ["mcp:tools"] + }; + }); + + Builder.Services.AddAuthorization(); + } + + public async ValueTask DisposeAsync() + { + _testCts.Cancel(); + try + { + await _testOAuthRunTask; + } + catch (OperationCanceledException) + { + } + finally + { + _testCts.Dispose(); + } + } + + [Fact] + public async Task CanAuthenticate() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task CannotAuthenticate_WithoutOAuthConfiguration() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + }, HttpClient, LoggerFactory); + + var httpEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal(HttpStatusCode.Unauthorized, httpEx.StatusCode); + } + + [Fact] + public async Task CannotAuthenticate_WithUnregisteredClient() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "unregistered-demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + }, + }, HttpClient, LoggerFactory); + + // The EqualException is thrown by HandleAuthorizationUrlAsync when the /authorize request gets a 400 + var equalEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task CanAuthenticate_WithDynamicClientRegistration() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new ClientOAuthOptions() + { + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + ClientName = "Test MCP Client", + ClientUri = new Uri("https://example.com/"), + Scopes = ["mcp:tools"] + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task CanAuthenticate_WithTokenRefresh() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "test-refresh-client", + ClientSecret = "test-refresh-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + }, + }, HttpClient, LoggerFactory); + + // The test-refresh-client should get an expired token first, + // then automatically refresh it to get a working token + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.True(_testOAuthServer.HasIssuedRefreshToken); + } + + [Fact] + public async Task CanAuthenticate_WithExtraParams() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + AdditionalAuthorizationParameters = new Dictionary + { + ["custom_param"] = "custom_value", + } + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(_lastAuthorizationUri?.Query); + Assert.Contains("custom_param=custom_value", _lastAuthorizationUri?.Query); + } + + [Fact] + public async Task CannotOverrideExistingParameters_WithExtraParams() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + + await using var app = Builder.Build(); + + app.MapMcp().RequireAuthorization(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new SseClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new() + { + ClientId = "demo-client", + ClientSecret = "demo-secret", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + AdditionalAuthorizationParameters = new Dictionary + { + ["redirect_uri"] = "custom_value", + } + }, + }, HttpClient, LoggerFactory); + + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public void CloneResourceMetadataClonesAllProperties() + { + var propertyNames = typeof(ProtectedResourceMetadata).GetProperties().Select(property => property.Name).ToList(); + + // Set metadata properties to non-default values to verify they're copied. + var metadata = new ProtectedResourceMetadata + { + Resource = new Uri("https://example.com/resource"), + AuthorizationServers = [new Uri("https://auth1.example.com/"), new Uri("https://auth2.example.com/")], + BearerMethodsSupported = ["header", "body", "query"], + ScopesSupported = ["read", "write", "admin"], + JwksUri = new Uri("https://example.com/.well-known/jwks.json"), + ResourceSigningAlgValuesSupported = ["RS256", "ES256"], + ResourceName = "Test Resource", + ResourceDocumentation = new Uri("https://docs.example.com/"), + ResourcePolicyUri = new Uri("https://example.com/policy"), + ResourceTosUri = new Uri("https://example.com/terms"), + TlsClientCertificateBoundAccessTokens = true, + AuthorizationDetailsTypesSupported = ["payment_initiation", "account_information"], + DpopSigningAlgValuesSupported = ["RS256", "PS256"], + DpopBoundAccessTokensRequired = true + }; + + // Use reflection to call the internal CloneResourceMetadata method + var handlerType = typeof(McpAuthenticationHandler); + var cloneMethod = handlerType.GetMethod("CloneResourceMetadata", BindingFlags.Static | BindingFlags.NonPublic); + Assert.NotNull(cloneMethod); + + var clonedMetadata = (ProtectedResourceMetadata?)cloneMethod.Invoke(null, [metadata]); + Assert.NotNull(clonedMetadata); + + // Ensure the cloned metadata is not the same instance + Assert.NotSame(metadata, clonedMetadata); + + // Verify Resource property + Assert.Equal(metadata.Resource, clonedMetadata.Resource); + Assert.True(propertyNames.Remove(nameof(metadata.Resource))); + + // Verify AuthorizationServers list is cloned and contains the same values + Assert.NotSame(metadata.AuthorizationServers, clonedMetadata.AuthorizationServers); + Assert.Equal(metadata.AuthorizationServers, clonedMetadata.AuthorizationServers); + Assert.True(propertyNames.Remove(nameof(metadata.AuthorizationServers))); + + // Verify BearerMethodsSupported list is cloned and contains the same values + Assert.NotSame(metadata.BearerMethodsSupported, clonedMetadata.BearerMethodsSupported); + Assert.Equal(metadata.BearerMethodsSupported, clonedMetadata.BearerMethodsSupported); + Assert.True(propertyNames.Remove(nameof(metadata.BearerMethodsSupported))); + + // Verify ScopesSupported list is cloned and contains the same values + Assert.NotSame(metadata.ScopesSupported, clonedMetadata.ScopesSupported); + Assert.Equal(metadata.ScopesSupported, clonedMetadata.ScopesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.ScopesSupported))); + + // Verify JwksUri property + Assert.Equal(metadata.JwksUri, clonedMetadata.JwksUri); + Assert.True(propertyNames.Remove(nameof(metadata.JwksUri))); + + // Verify ResourceSigningAlgValuesSupported list is cloned (nullable list) + Assert.NotSame(metadata.ResourceSigningAlgValuesSupported, clonedMetadata.ResourceSigningAlgValuesSupported); + Assert.Equal(metadata.ResourceSigningAlgValuesSupported, clonedMetadata.ResourceSigningAlgValuesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceSigningAlgValuesSupported))); + + // Verify ResourceName property + Assert.Equal(metadata.ResourceName, clonedMetadata.ResourceName); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceName))); + + // Verify ResourceDocumentation property + Assert.Equal(metadata.ResourceDocumentation, clonedMetadata.ResourceDocumentation); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceDocumentation))); + + // Verify ResourcePolicyUri property + Assert.Equal(metadata.ResourcePolicyUri, clonedMetadata.ResourcePolicyUri); + Assert.True(propertyNames.Remove(nameof(metadata.ResourcePolicyUri))); + + // Verify ResourceTosUri property + Assert.Equal(metadata.ResourceTosUri, clonedMetadata.ResourceTosUri); + Assert.True(propertyNames.Remove(nameof(metadata.ResourceTosUri))); + + // Verify TlsClientCertificateBoundAccessTokens property + Assert.Equal(metadata.TlsClientCertificateBoundAccessTokens, clonedMetadata.TlsClientCertificateBoundAccessTokens); + Assert.True(propertyNames.Remove(nameof(metadata.TlsClientCertificateBoundAccessTokens))); + + // Verify AuthorizationDetailsTypesSupported list is cloned (nullable list) + Assert.NotSame(metadata.AuthorizationDetailsTypesSupported, clonedMetadata.AuthorizationDetailsTypesSupported); + Assert.Equal(metadata.AuthorizationDetailsTypesSupported, clonedMetadata.AuthorizationDetailsTypesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.AuthorizationDetailsTypesSupported))); + + // Verify DpopSigningAlgValuesSupported list is cloned (nullable list) + Assert.NotSame(metadata.DpopSigningAlgValuesSupported, clonedMetadata.DpopSigningAlgValuesSupported); + Assert.Equal(metadata.DpopSigningAlgValuesSupported, clonedMetadata.DpopSigningAlgValuesSupported); + Assert.True(propertyNames.Remove(nameof(metadata.DpopSigningAlgValuesSupported))); + + // Verify DpopBoundAccessTokensRequired property + Assert.Equal(metadata.DpopBoundAccessTokensRequired, clonedMetadata.DpopBoundAccessTokensRequired); + Assert.True(propertyNames.Remove(nameof(metadata.DpopBoundAccessTokensRequired))); + + // Ensure we've checked every property. When new properties get added, we'll have to update this test along with the CloneResourceMetadata implementation. + Assert.Empty(propertyNames); + } + + private async Task HandleAuthorizationUrlAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken) + { + _lastAuthorizationUri = authorizationUri; + + var redirectResponse = await HttpClient.GetAsync(authorizationUri, cancellationToken); + Assert.Equal(HttpStatusCode.Redirect, redirectResponse.StatusCode); + var location = redirectResponse.Headers.Location; + + if (location is not null && !string.IsNullOrEmpty(location.Query)) + { + var queryParams = QueryHelpers.ParseQuery(location.Query); + return queryParams["code"]; + } + + return null; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 394fa4979..9b3c91b94 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -94,7 +94,7 @@ public async Task CallTool_Sse_EchoServer() // assert Assert.NotNull(result); - Assert.False(result.IsError); + Assert.Null(result.IsError); var textContent = Assert.Single(result.Content.OfType()); Assert.Equal("Echo: Hello MCP!", textContent.Text); } @@ -115,10 +115,10 @@ public async Task CallTool_EchoSessionId_ReturnsTheSameSessionId() Assert.NotNull(result2); Assert.NotNull(result3); - Assert.False(result1.IsError); - Assert.False(result2.IsError); - Assert.False(result3.IsError); - + Assert.Null(result1.IsError); + Assert.Null(result2.IsError); + Assert.Null(result3.IsError); + var textContent1 = Assert.Single(result1.Content.OfType()); var textContent2 = Assert.Single(result2.Content.OfType()); var textContent3 = Assert.Single(result3.Content.OfType()); @@ -267,10 +267,10 @@ public async Task Sampling_Sse_TestServer() // Call the server's sampleLLM tool which should trigger our sampling handler var result = await client.CallToolAsync("sampleLLM", new Dictionary - { - ["prompt"] = "Test prompt", - ["maxTokens"] = 100 - }, + { + ["prompt"] = "Test prompt", + ["maxTokens"] = 100 + }, cancellationToken: TestContext.Current.CancellationToken); // assert @@ -288,7 +288,7 @@ public async Task CallTool_Sse_EchoServer_Concurrently() for (int i = 0; i < 4; i++) { var client = (i % 2 == 0) ? client1 : client2; - var result = await client.CallToolAsync( + var result = await client.CallToolAsync( "echo", new Dictionary { @@ -298,7 +298,7 @@ public async Task CallTool_Sse_EchoServer_Concurrently() ); Assert.NotNull(result); - Assert.False(result.IsError); + Assert.Null(result.IsError); var textContent = Assert.Single(result.Content.OfType()); Assert.Equal($"Echo: Hello MCP! {i}", textContent.Text); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index 0a4238500..f31621307 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -20,7 +20,7 @@ public async Task Allows_Customizing_Route(string pattern) await app.StartAsync(TestContext.Current.CancellationToken); - using var response = await HttpClient.GetAsync($"http://localhost{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + using var response = await HttpClient.GetAsync($"http://localhost:5000{pattern}/sse", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); response.EnsureSuccessStatusCode(); using var sseStream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); using var sseStreamReader = new StreamReader(sseStream, System.Text.Encoding.UTF8); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index f14cc10a6..cb1f86db9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -56,7 +56,7 @@ public async Task StreamableHttpMode_Works_WithRootEndpoint() await using var mcpClient = await ConnectAsync("/", new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), TransportMode = HttpTransportMode.AutoDetect }); @@ -82,7 +82,7 @@ public async Task AutoDetectMode_Works_WithRootEndpoint() await using var mcpClient = await ConnectAsync("/", new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), TransportMode = HttpTransportMode.AutoDetect }); @@ -110,7 +110,7 @@ public async Task AutoDetectMode_Works_WithSseEndpoint() await using var mcpClient = await ConnectAsync("/sse", new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), TransportMode = HttpTransportMode.AutoDetect }); @@ -138,7 +138,7 @@ public async Task SseMode_Works_WithSseEndpoint() await using var mcpClient = await ConnectAsync(transportOptions: new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), TransportMode = HttpTransportMode.Sse }); @@ -171,14 +171,16 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia await app.StartAsync(TestContext.Current.CancellationToken); - await using var mcpClient = await ConnectAsync(clientOptions: new() + await using (var mcpClient = await ConnectAsync(clientOptions: new() { ProtocolVersion = "2025-03-26", - }); - await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + })) + { + await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + } - // The header should be included in the GET request, the initialized notification, and the tools/list call. - Assert.Equal(3, protocolVersionHeaderValues.Count); + // The header should be included in the GET request, the initialized notification, the tools/list call, and the delete request. + Assert.NotEmpty(protocolVersionHeaderValues); Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v)); } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 2690352f1..cf54e7774 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -31,9 +31,9 @@ protected async Task ConnectAsync( // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; - await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions() + await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions { - Endpoint = new Uri($"http://localhost{path}/"), + Endpoint = new Uri($"http://localhost:5000{path}"), TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); @@ -80,7 +80,7 @@ IHttpContextAccessor is not currently supported with non-stateless Streamable HT await using var mcpClient = await ConnectAsync(); var response = await mcpClient.CallToolAsync( - "EchoWithUserName", + "echo_with_user_name", new Dictionary() { ["message"] = "Hello world!" }, cancellationToken: TestContext.Current.CancellationToken); @@ -171,7 +171,7 @@ public async Task Sampling_DoesNotCloseStream_Prematurely() }, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); - Assert.False(result.IsError); + Assert.Null(result.IsError); var textContent = Assert.Single(result.Content); Assert.Equal("text", textContent.Type); Assert.Equal("Sampling completed successfully. Client responded: Sampling response from client", Assert.IsType(textContent).Text); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj b/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj index bbcac5f53..34801c736 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ModelContextProtocol.AspNetCore.Tests.csproj @@ -34,6 +34,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive all + @@ -56,6 +57,7 @@ + diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 4537d16bf..8191f6091 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -17,7 +17,7 @@ public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : Kestr { private readonly SseClientTransportOptions DefaultTransportOptions = new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", }; @@ -149,11 +149,11 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes() var tools = await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(2, tools.Count); - Assert.Contains(tools, tools => tools.Name == "Echo"); + Assert.Contains(tools, tools => tools.Name == "echo"); Assert.Contains(tools, tools => tools.Name == "sampleLLM"); var echoResponse = await mcpClient.CallToolAsync( - "Echo", + "echo", new Dictionary { ["message"] = "from client!" @@ -195,9 +195,9 @@ public async Task AdditionalHeaders_AreSent_InGetAndPostRequests() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - var sseOptions = new SseClientTransportOptions() + var sseOptions = new SseClientTransportOptions { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", AdditionalHeaders = new Dictionary { @@ -222,9 +222,9 @@ public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - var sseOptions = new SseClientTransportOptions() + var sseOptions = new SseClientTransportOptions { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", AdditionalHeaders = new Dictionary() { @@ -251,7 +251,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints, b response.Headers.ContentType = "text/event-stream"; - await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost/message"); + await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost:5000/message"); session = transport; try diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 1eee02032..2aa675c84 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Tests.Utils; using ModelContextProtocol.TestSseServer; +using System.Net; namespace ModelContextProtocol.AspNetCore.Tests; @@ -19,23 +20,23 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable private SseClientTransportOptions DefaultTransportOptions { get; set; } = new() { - Endpoint = new("http://localhost/"), + Endpoint = new("http://localhost:5000/"), }; public SseServerIntegrationTestFixture() { - var socketsHttpHandler = new SocketsHttpHandler() + var socketsHttpHandler = new SocketsHttpHandler { ConnectCallback = (context, token) => { - var connection = _inMemoryTransport.CreateConnection(); + var connection = _inMemoryTransport.CreateConnection(new DnsEndPoint("localhost", 5000)); return new(connection.ClientStream); }, }; HttpClient = new HttpClient(socketsHttpHandler) { - BaseAddress = new("http://localhost/"), + BaseAddress = new("http://localhost:5000/"), }; _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index eb89912ac..2d4a78685 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -10,7 +10,7 @@ public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, { protected override SseClientTransportOptions ClientTransportOptions => new() { - Endpoint = new Uri("http://localhost/sse"), + Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index a9e2e5f54..d16e510cc 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -7,7 +7,7 @@ public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fix { protected override SseClientTransportOptions ClientTransportOptions => new() { - Endpoint = new Uri("http://localhost/stateless"), + Endpoint = new("http://localhost:5000/stateless"), Name = "In-memory Streamable HTTP Client", TransportMode = HttpTransportMode.StreamableHttp, }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index 1e21eb45d..b50a43edc 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -16,7 +16,7 @@ public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMem private readonly SseClientTransportOptions DefaultTransportOptions = new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", TransportMode = HttpTransportMode.StreamableHttp, }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 119659aec..7ce3516ef 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -85,7 +85,7 @@ private async Task StartAsync(bool enableDelete = false) return Results.Json(new JsonRpcResponse { Id = request.Id, - Result = JsonSerializer.SerializeToNode(new CallToolResult() + Result = JsonSerializer.SerializeToNode(new CallToolResult { Content = [new TextContentBlock { Text = parameters.Arguments["message"].ToString() }], }, McpJsonUtilities.DefaultOptions), @@ -114,7 +114,7 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer() await using var transport = new SseClientTransport(new() { - Endpoint = new("http://localhost/mcp"), + Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); @@ -134,7 +134,7 @@ public async Task CanCallToolConcurrently() await using var transport = new SseClientTransport(new() { - Endpoint = new("http://localhost/mcp"), + Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); @@ -160,7 +160,7 @@ public async Task SendsDeleteRequestOnDispose() await using var transport = new SseClientTransport(new() { - Endpoint = new("http://localhost/mcp"), + Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 7c4366f16..3524c60a4 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -13,7 +13,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur protected override SseClientTransportOptions ClientTransportOptions => new() { - Endpoint = new Uri("http://localhost/"), + Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", TransportMode = HttpTransportMode.StreamableHttp, }; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index 45aaeb5b1..4ae743f72 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -8,8 +8,6 @@ namespace ModelContextProtocol.AspNetCore.Tests.Utils; public class KestrelInMemoryTest : LoggedTest { - private readonly KestrelInMemoryTransport _inMemoryTransport = new(); - public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { @@ -17,17 +15,16 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) // or a helper that does the same every test. But clear out the existing socket transport to avoid potential port conflicts. Builder = WebApplication.CreateSlimBuilder(); Builder.Services.RemoveAll(); - Builder.Services.AddSingleton(_inMemoryTransport); + Builder.Services.AddSingleton(KestrelInMemoryTransport); Builder.Services.AddSingleton(XunitLoggerProvider); - HttpClient = new HttpClient(new SocketsHttpHandler() + SocketsHttpHandler.ConnectCallback = (context, token) => { - ConnectCallback = (context, token) => - { - var connection = _inMemoryTransport.CreateConnection(); - return new(connection.ClientStream); - }, - }) + var connection = KestrelInMemoryTransport.CreateConnection(context.DnsEndPoint); + return new(connection.ClientStream); + }; + + HttpClient = new HttpClient(SocketsHttpHandler) { BaseAddress = new Uri("http://localhost:5000/"), Timeout = TimeSpan.FromSeconds(10), @@ -38,6 +35,10 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) public HttpClient HttpClient { get; } + public SocketsHttpHandler SocketsHttpHandler { get; } = new(); + + public KestrelInMemoryTransport KestrelInMemoryTransport { get; } = new(); + public override void Dispose() { HttpClient.Dispose(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs index 399e9a833..71809ad6c 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs @@ -1,50 +1,59 @@ using Microsoft.AspNetCore.Connections; +using System.Collections.Concurrent; using System.Net; using System.Threading.Channels; namespace ModelContextProtocol.AspNetCore.Tests.Utils; -public sealed class KestrelInMemoryTransport : IConnectionListenerFactory, IConnectionListener +public sealed class KestrelInMemoryTransport : IConnectionListenerFactory { - private readonly Channel _acceptQueue = Channel.CreateUnbounded(); - private EndPoint? _endPoint; + // socket accept queues keyed by listen port. + private readonly ConcurrentDictionary> _acceptQueues = []; - public EndPoint EndPoint => _endPoint ?? throw new InvalidOperationException("EndPoint is not set. Call BindAsync first."); - - public KestrelInMemoryConnection CreateConnection() + public KestrelInMemoryConnection CreateConnection(EndPoint endpoint) { var connection = new KestrelInMemoryConnection(); - _acceptQueue.Writer.TryWrite(connection); + GetAcceptQueue(endpoint).Writer.TryWrite(connection); return connection; } - public async ValueTask AcceptAsync(CancellationToken cancellationToken = default) + public ValueTask BindAsync(EndPoint endpoint, CancellationToken cancellationToken = default) => + new(new KestrelInMemoryListener(endpoint, GetAcceptQueue(endpoint))); + + private Channel GetAcceptQueue(EndPoint endpoint) => + _acceptQueues.GetOrAdd(GetEndpointPort(endpoint), _ => Channel.CreateUnbounded()); + + private static int GetEndpointPort(EndPoint endpoint) => + endpoint switch + { + DnsEndPoint dnsEndpoint => dnsEndpoint.Port, + IPEndPoint ipEndpoint => ipEndpoint.Port, + _ => throw new InvalidOperationException($"Unexpected endpoint type: '{endpoint.GetType()}'"), + }; + + private sealed class KestrelInMemoryListener(EndPoint endpoint, Channel acceptQueue) : IConnectionListener { - if (await _acceptQueue.Reader.WaitToReadAsync(cancellationToken)) + public EndPoint EndPoint => endpoint; + + public async ValueTask AcceptAsync(CancellationToken cancellationToken = default) { - while (_acceptQueue.Reader.TryRead(out var item)) + if (await acceptQueue.Reader.WaitToReadAsync(cancellationToken)) { - return item; + while (acceptQueue.Reader.TryRead(out var item)) + { + return item; + } } - } - - return null; - } - public ValueTask BindAsync(EndPoint endpoint, CancellationToken cancellationToken = default) - { - _endPoint = endpoint; - return new ValueTask(this); - } + return null; + } - public ValueTask DisposeAsync() - { - return UnbindAsync(default); - } + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) + { + acceptQueue.Writer.TryComplete(); + return default; + } - public ValueTask UnbindAsync(CancellationToken cancellationToken = default) - { - _acceptQueue.Writer.TryComplete(); - return default; + public ValueTask DisposeAsync() => UnbindAsync(CancellationToken.None); } } diff --git a/tests/ModelContextProtocol.TestOAuthServer/AuthorizationCodeInfo.cs b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationCodeInfo.cs new file mode 100644 index 000000000..9d7142ce5 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationCodeInfo.cs @@ -0,0 +1,34 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents authorization code information for OAuth flow. +/// +internal sealed class AuthorizationCodeInfo +{ + /// + /// Gets or sets the client ID associated with this authorization code. + /// + public required string ClientId { get; init; } + + /// + /// Gets or sets the redirect URI associated with this authorization code. + /// + public required string RedirectUri { get; init; } + + /// + /// Gets or sets the code challenge associated with this authorization code (for PKCE). + /// + public required string CodeChallenge { get; init; } + + /// + /// Gets or sets the list of scopes approved for this authorization code. + /// + public List Scope { get; init; } = []; + + /// + /// Gets or sets the optional resource URI this authorization code is for. + /// + public Uri? Resource { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/AuthorizationServerMetadata.cs b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationServerMetadata.cs new file mode 100644 index 000000000..32472a883 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/AuthorizationServerMetadata.cs @@ -0,0 +1,63 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the authorization server metadata for OAuth discovery. +/// +internal sealed class AuthorizationServerMetadata +{ + /// + /// Gets or sets the issuer URL. + /// + [JsonPropertyName("issuer")] + public required Uri Issuer { get; init; } + + /// + /// Gets or sets the authorization endpoint URL. + /// + [JsonPropertyName("authorization_endpoint")] + public required Uri AuthorizationEndpoint { get; init; } + + /// + /// Gets or sets the token endpoint URL. + /// + [JsonPropertyName("token_endpoint")] + public required Uri TokenEndpoint { get; init; } + + /// + /// Gets the introspection endpoint URL. + /// + [JsonPropertyName("introspection_endpoint")] + public Uri? IntrospectionEndpoint => new Uri($"{Issuer}/introspect"); + + /// + /// Gets or sets the response types supported by this server. + /// + [JsonPropertyName("response_types_supported")] + public required List ResponseTypesSupported { get; init; } + + /// + /// Gets or sets the grant types supported by this server. + /// + [JsonPropertyName("grant_types_supported")] + public required List GrantTypesSupported { get; init; } + + /// + /// Gets or sets the token endpoint authentication methods supported by this server. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public required List TokenEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the code challenge methods supported by this server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + public required List CodeChallengeMethodsSupported { get; init; } + + /// + /// Gets or sets the scopes supported by this server. + /// + [JsonPropertyName("scopes_supported")] + public List? ScopesSupported { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs b/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs new file mode 100644 index 000000000..7983476fa --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs @@ -0,0 +1,24 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents client information for OAuth flow. +/// +internal sealed class ClientInfo +{ + /// + /// Gets or sets the client ID. + /// + public required string ClientId { get; init; } + + /// + /// Gets or sets the client secret. + /// + public required string ClientSecret { get; init; } + + /// + /// Gets or sets the list of redirect URIs allowed for this client. + /// + public List RedirectUris { get; init; } = []; +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationRequest.cs b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationRequest.cs new file mode 100644 index 000000000..50592bbea --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationRequest.cs @@ -0,0 +1,93 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a client registration request as defined in RFC 7591. +/// +internal sealed class ClientRegistrationRequest +{ + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public required List RedirectUris { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + public List? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + public List? ResponseTypes { get; init; } + + /// + /// Gets or sets the human-readable name of the client. + /// + [JsonPropertyName("client_name")] + public string? ClientName { get; init; } + + /// + /// Gets or sets the URL of the client's home page. + /// + [JsonPropertyName("client_uri")] + public string? ClientUri { get; init; } + + /// + /// Gets or sets the URL for the client's logo. + /// + [JsonPropertyName("logo_uri")] + public string? LogoUri { get; init; } + + /// + /// Gets or sets the scope values that the client will use. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } + + /// + /// Gets or sets the contacts for the client. + /// + [JsonPropertyName("contacts")] + public List? Contacts { get; init; } + + /// + /// Gets or sets the URL for the client's terms of service. + /// + [JsonPropertyName("tos_uri")] + public string? TosUri { get; init; } + + /// + /// Gets or sets the URL for the client's privacy policy. + /// + [JsonPropertyName("policy_uri")] + public string? PolicyUri { get; init; } + + /// + /// Gets or sets the JWK Set URL for the client. + /// + [JsonPropertyName("jwks_uri")] + public string? JwksUri { get; init; } + + /// + /// Gets or sets the software identifier for the client. + /// + [JsonPropertyName("software_id")] + public string? SoftwareId { get; init; } + + /// + /// Gets or sets the software version for the client. + /// + [JsonPropertyName("software_version")] + public string? SoftwareVersion { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationResponse.cs new file mode 100644 index 000000000..3833c4908 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ClientRegistrationResponse.cs @@ -0,0 +1,147 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a client registration response as defined in RFC 7591. +/// +internal sealed class ClientRegistrationResponse +{ + /// + /// Gets or sets the client identifier. + /// + [JsonPropertyName("client_id")] + public required string ClientId { get; init; } + + /// + /// Gets or sets the client secret. + /// + [JsonPropertyName("client_secret")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ClientSecret { get; init; } + + /// + /// Gets or sets the redirect URIs for the client. + /// + [JsonPropertyName("redirect_uris")] + public required List RedirectUris { get; init; } + + /// + /// Gets or sets the registration access token. + /// + [JsonPropertyName("registration_access_token")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RegistrationAccessToken { get; init; } + + /// + /// Gets or sets the registration client URI. + /// + [JsonPropertyName("registration_client_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RegistrationClientUri { get; init; } + + /// + /// Gets or sets the client ID issued timestamp. + /// + [JsonPropertyName("client_id_issued_at")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public long? ClientIdIssuedAt { get; init; } + + /// + /// Gets or sets the client secret expiration time. + /// + [JsonPropertyName("client_secret_expires_at")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public long? ClientSecretExpiresAt { get; init; } + + /// + /// Gets or sets the token endpoint authentication method. + /// + [JsonPropertyName("token_endpoint_auth_method")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? TokenEndpointAuthMethod { get; init; } + + /// + /// Gets or sets the grant types that the client will use. + /// + [JsonPropertyName("grant_types")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? GrantTypes { get; init; } + + /// + /// Gets or sets the response types that the client will use. + /// + [JsonPropertyName("response_types")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ResponseTypes { get; init; } + + /// + /// Gets or sets the human-readable name of the client. + /// + [JsonPropertyName("client_name")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ClientName { get; init; } + + /// + /// Gets or sets the URL of the client's home page. + /// + [JsonPropertyName("client_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ClientUri { get; init; } + + /// + /// Gets or sets the URL for the client's logo. + /// + [JsonPropertyName("logo_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? LogoUri { get; init; } + + /// + /// Gets or sets the scope values that the client will use. + /// + [JsonPropertyName("scope")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Scope { get; init; } + + /// + /// Gets or sets the contacts for the client. + /// + [JsonPropertyName("contacts")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? Contacts { get; init; } + + /// + /// Gets or sets the URL for the client's terms of service. + /// + [JsonPropertyName("tos_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? TosUri { get; init; } + + /// + /// Gets or sets the URL for the client's privacy policy. + /// + [JsonPropertyName("policy_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? PolicyUri { get; init; } + + /// + /// Gets or sets the JWK Set URL for the client. + /// + [JsonPropertyName("jwks_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? JwksUri { get; init; } + + /// + /// Gets or sets the software identifier for the client. + /// + [JsonPropertyName("software_id")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? SoftwareId { get; init; } + + /// + /// Gets or sets the software version for the client. + /// + [JsonPropertyName("software_version")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? SoftwareVersion { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/JsonWebKey.cs b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKey.cs new file mode 100644 index 000000000..562efa526 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKey.cs @@ -0,0 +1,45 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a JSON Web Key. +/// +internal sealed class JsonWebKey +{ + /// + /// Gets or sets the key type (e.g., "RSA"). + /// + [JsonPropertyName("kty")] + public required string KeyType { get; init; } + + /// + /// Gets or sets the intended use of the key (e.g., "sig" for signature). + /// + [JsonPropertyName("use")] + public required string Use { get; init; } + + /// + /// Gets or sets the key ID. + /// + [JsonPropertyName("kid")] + public required string KeyId { get; init; } + + /// + /// Gets or sets the algorithm intended for use with the key (e.g., "RS256"). + /// + [JsonPropertyName("alg")] + public required string Algorithm { get; init; } + + /// + /// Gets or sets the RSA exponent (base64url-encoded). + /// + [JsonPropertyName("e")] + public required string Exponent { get; init; } + + /// + /// Gets or sets the RSA modulus (base64url-encoded). + /// + [JsonPropertyName("n")] + public required string Modulus { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/JsonWebKeySet.cs b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKeySet.cs new file mode 100644 index 000000000..223407b7a --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/JsonWebKeySet.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents a JSON Web Key Set (JWKS) response. +/// +internal sealed class JsonWebKeySet +{ + /// + /// Gets or sets the array of JSON Web Keys. + /// + [JsonPropertyName("keys")] + public required JsonWebKey[] Keys { get; init; } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/ModelContextProtocol.TestOAuthServer.csproj b/tests/ModelContextProtocol.TestOAuthServer/ModelContextProtocol.TestOAuthServer.csproj new file mode 100644 index 000000000..51092f564 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/ModelContextProtocol.TestOAuthServer.csproj @@ -0,0 +1,9 @@ + + + + net9.0;net8.0 + enable + enable + + + diff --git a/tests/ModelContextProtocol.TestOAuthServer/OAuthErrorResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/OAuthErrorResponse.cs new file mode 100644 index 000000000..c9174fa39 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/OAuthErrorResponse.cs @@ -0,0 +1,21 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents an OAuth error response. +/// +internal sealed class OAuthErrorResponse +{ + /// + /// Gets or sets the error code. + /// + [JsonPropertyName("error")] + public required string Error { get; init; } + + /// + /// Gets or sets the error description. + /// + [JsonPropertyName("error_description")] + public required string ErrorDescription { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/OAuthJsonContext.cs b/tests/ModelContextProtocol.TestOAuthServer/OAuthJsonContext.cs new file mode 100644 index 000000000..6caaaea01 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/OAuthJsonContext.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +[JsonSerializable(typeof(OAuthServerMetadata))] +[JsonSerializable(typeof(AuthorizationServerMetadata))] +[JsonSerializable(typeof(TokenResponse))] +[JsonSerializable(typeof(JsonWebKeySet))] +[JsonSerializable(typeof(JsonWebKey))] +[JsonSerializable(typeof(TokenIntrospectionResponse))] +[JsonSerializable(typeof(OAuthErrorResponse))] +[JsonSerializable(typeof(ClientRegistrationRequest))] +[JsonSerializable(typeof(ClientRegistrationResponse))] +[JsonSerializable(typeof(Dictionary))] +internal sealed partial class OAuthJsonContext : JsonSerializerContext; diff --git a/tests/ModelContextProtocol.TestOAuthServer/OAuthServerMetadata.cs b/tests/ModelContextProtocol.TestOAuthServer/OAuthServerMetadata.cs new file mode 100644 index 000000000..646a39929 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/OAuthServerMetadata.cs @@ -0,0 +1,174 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the OAuth 2.0 Authorization Server Metadata as defined in RFC 8414. +/// +internal sealed class OAuthServerMetadata +{ + /// + /// Gets or sets the issuer URL. + /// REQUIRED. The authorization server's issuer identifier, which is a URL that uses the "https" scheme and has no query or fragment components. + /// + [JsonPropertyName("issuer")] + public required string Issuer { get; init; } + + /// + /// Gets or sets the authorization endpoint URL. + /// URL of the authorization server's authorization endpoint. This is REQUIRED unless no grant types are supported that use the authorization endpoint. + /// + [JsonPropertyName("authorization_endpoint")] + public required string AuthorizationEndpoint { get; init; } + + /// + /// Gets or sets the token endpoint URL. + /// URL of the authorization server's token endpoint. This is REQUIRED unless only the implicit grant type is supported. + /// + [JsonPropertyName("token_endpoint")] + public required string TokenEndpoint { get; init; } + + /// + /// Gets or sets the JWKS URI. + /// OPTIONAL. URL of the authorization server's JWK Set document. + /// + [JsonPropertyName("jwks_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? JwksUri { get; init; } + + /// + /// Gets or sets the registration endpoint URL for dynamic client registration. + /// OPTIONAL. URL of the authorization server's OAuth 2.0 Dynamic Client Registration endpoint. + /// + [JsonPropertyName("registration_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RegistrationEndpoint { get; init; } + + /// + /// Gets or sets the scopes supported by this server. + /// RECOMMENDED. JSON array containing a list of the OAuth 2.0 scope values that this server supports. + /// + [JsonPropertyName("scopes_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ScopesSupported { get; init; } + + /// + /// Gets or sets the response types supported by this server. + /// RECOMMENDED. JSON array containing a list of the OAuth 2.0 "response_type" values that this server supports. + /// + [JsonPropertyName("response_types_supported")] + public required List ResponseTypesSupported { get; init; } + + /// + /// Gets or sets the response modes supported by this server. + /// OPTIONAL. JSON array containing a list of the OAuth 2.0 "response_mode" values that this server supports. + /// + [JsonPropertyName("response_modes_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ResponseModesSupported { get; init; } + + /// + /// Gets or sets the grant types supported by this server. + /// OPTIONAL. JSON array containing a list of the OAuth 2.0 grant type values that this server supports. + /// + [JsonPropertyName("grant_types_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? GrantTypesSupported { get; init; } + + /// + /// Gets or sets the token endpoint authentication methods supported by this server. + /// OPTIONAL. JSON array containing a list of client authentication methods supported by this token endpoint. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? TokenEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the token endpoint authentication signing algorithms supported by this server. + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms supported by the token endpoint. + /// + [JsonPropertyName("token_endpoint_auth_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? TokenEndpointAuthSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the introspection endpoint URL. + /// OPTIONAL. URL of the authorization server's OAuth 2.0 introspection endpoint. + /// + [JsonPropertyName("introspection_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? IntrospectionEndpoint { get; init; } + + /// + /// Gets or sets the introspection endpoint authentication methods supported by this server. + /// OPTIONAL. JSON array containing a list of client authentication methods supported by this introspection endpoint. + /// + [JsonPropertyName("introspection_endpoint_auth_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? IntrospectionEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the introspection endpoint authentication signing algorithms supported by this server. + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms supported by the introspection endpoint. + /// + [JsonPropertyName("introspection_endpoint_auth_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? IntrospectionEndpointAuthSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the revocation endpoint URL. + /// OPTIONAL. URL of the authorization server's OAuth 2.0 revocation endpoint. + /// + [JsonPropertyName("revocation_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RevocationEndpoint { get; init; } + + /// + /// Gets or sets the revocation endpoint authentication methods supported by this server. + /// OPTIONAL. JSON array containing a list of client authentication methods supported by this revocation endpoint. + /// + [JsonPropertyName("revocation_endpoint_auth_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? RevocationEndpointAuthMethodsSupported { get; init; } + + /// + /// Gets or sets the revocation endpoint authentication signing algorithms supported by this server. + /// OPTIONAL. JSON array containing a list of the JWS signing algorithms supported by the revocation endpoint. + /// + [JsonPropertyName("revocation_endpoint_auth_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? RevocationEndpointAuthSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the code challenge methods supported by this server. + /// OPTIONAL. JSON array containing a list of Proof Key for Code Exchange (PKCE) code challenge methods supported by this server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? CodeChallengeMethodsSupported { get; init; } + + // OpenID Connect specific fields that are commonly included in OAuth metadata + /// + /// Gets or sets the subject types supported by this server. + /// REQUIRED for OpenID Connect. JSON array containing a list of the Subject Identifier types that this OP supports. + /// + [JsonPropertyName("subject_types_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? SubjectTypesSupported { get; init; } + + /// + /// Gets or sets the ID token signing algorithms supported by this server. + /// REQUIRED for OpenID Connect. JSON array containing a list of the JWS signing algorithms (alg values) supported by the OP for the ID Token. + /// + [JsonPropertyName("id_token_signing_alg_values_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? IdTokenSigningAlgValuesSupported { get; init; } + + /// + /// Gets or sets the claims supported by this server. + /// RECOMMENDED for OpenID Connect. JSON array containing a list of the Claim Names of the Claims that the OpenID Provider MAY be able to supply values for. + /// + [JsonPropertyName("claims_supported")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? ClaimsSupported { get; init; } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/Program.cs b/tests/ModelContextProtocol.TestOAuthServer/Program.cs new file mode 100644 index 000000000..3970394b6 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/Program.cs @@ -0,0 +1,634 @@ +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.WebUtilities; +using System.Collections.Concurrent; +using System.Globalization; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.TestOAuthServer; + +public sealed class Program +{ + private const int _port = 7029; + private static readonly string _url = $"https://localhost:{_port}"; + + // Port 5000 is used by tests and port 7071 is used by the ProtectedMCPServer sample + private static readonly string[] ValidResources = ["http://localhost:5000/", "http://localhost:7071/"]; + + private readonly ConcurrentDictionary _authCodes = new(); + private readonly ConcurrentDictionary _tokens = new(); + private readonly ConcurrentDictionary _clients = new(); + + private readonly RSA _rsa; + private readonly string _keyId; + + private readonly ILoggerProvider? _loggerProvider; + private readonly IConnectionListenerFactory? _kestrelTransport; + + /// + /// Initializes a new instance of the class with logging and transport parameters. + /// + /// Optional logger provider for logging. + /// Optional Kestrel transport for in-memory connections. + public Program(ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null) + { + _rsa = RSA.Create(2048); + _keyId = Guid.NewGuid().ToString(); + _loggerProvider = loggerProvider; + _kestrelTransport = kestrelTransport; + } + + // Track if we've already issued an already-expired token for the CanAuthenticate_WithTokenRefresh test which uses the test-refresh-client registration. + public bool HasIssuedExpiredToken { get; set; } + public bool HasIssuedRefreshToken { get; set; } + + /// + /// Entry point for the application. + /// + /// Command line arguments. + /// A task representing the asynchronous operation. + public static Task Main(string[] args) => new Program().RunServerAsync(args); + + /// + /// Runs the OAuth server with the specified parameters. + /// + /// Command line arguments. + /// Cancellation token to stop the server. + /// A task representing the asynchronous operation. + public async Task RunServerAsync(string[]? args = null, CancellationToken cancellationToken = default) + { + Console.WriteLine("Starting in-memory test-only OAuth Server..."); + + var builder = WebApplication.CreateEmptyBuilder(new() + { + Args = args, + }); + + if (_kestrelTransport is not null) + { + // Add passed-in transport before calling UseKestrel() to avoid the SocketsHttpHandler getting added. + builder.Services.AddSingleton(_kestrelTransport); + } + + builder.WebHost.UseKestrel(kestrelOptions => + { + kestrelOptions.ListenLocalhost(_port, listenOptions => + { + listenOptions.UseHttps(); + }); + }); + + builder.Services.AddRoutingCore(); + builder.Services.AddLogging(); + + builder.Services.ConfigureHttpJsonOptions(jsonOptions => + { + jsonOptions.SerializerOptions.TypeInfoResolverChain.Add(OAuthJsonContext.Default); + }); + + builder.Logging.AddConsole(); + if (_loggerProvider is not null) + { + builder.Logging.AddProvider(_loggerProvider); + } + + var app = builder.Build(); + + app.UseRouting(); + app.UseEndpoints(_ => { }); + + // Set up the demo client + var clientId = "demo-client"; + var clientSecret = "demo-secret"; + _clients[clientId] = new ClientInfo + { + ClientId = clientId, + ClientSecret = clientSecret, + RedirectUris = ["http://localhost:1179/callback"], + }; + + // When this client ID is used, the first token issued will already be expired to make + // testing the refresh flow easier. + _clients["test-refresh-client"] = new ClientInfo + { + ClientId = "test-refresh-client", + ClientSecret = "test-refresh-secret", + RedirectUris = ["http://localhost:1179/callback"], + }; + + // The MCP spec tells the client to use /.well-known/oauth-authorization-server but AddJwtBearer looks for + // /.well-known/openid-configuration by default. To make things easier, we support both with the same response + // which seems to be common. Ex. https://github.com/keycloak/keycloak/pull/29628 + // + // The requirements for these endpoints are at https://www.rfc-editor.org/rfc/rfc8414 and + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata respectively. + // They do differ, but it's close enough at least for our current testing to use the same response for both. + // See https://gist.github.com/localden/26d8bcf641703c08a5d8741aa9c3336c + string[] metadataEndpoints = ["/.well-known/oauth-authorization-server", "/.well-known/openid-configuration"]; + foreach (var metadataEndpoint in metadataEndpoints) + { + // OAuth 2.0 Authorization Server Metadata (RFC 8414) + app.MapGet(metadataEndpoint, () => + { + var metadata = new OAuthServerMetadata + { + Issuer = _url, + AuthorizationEndpoint = $"{_url}/authorize", + TokenEndpoint = $"{_url}/token", + JwksUri = $"{_url}/.well-known/jwks.json", + ResponseTypesSupported = ["code"], + SubjectTypesSupported = ["public"], + IdTokenSigningAlgValuesSupported = ["RS256"], + ScopesSupported = ["openid", "profile", "email", "mcp:tools"], + TokenEndpointAuthMethodsSupported = ["client_secret_post"], + ClaimsSupported = ["sub", "iss", "name", "email", "aud"], + CodeChallengeMethodsSupported = ["S256"], + GrantTypesSupported = ["authorization_code", "refresh_token"], + IntrospectionEndpoint = $"{_url}/introspect", + RegistrationEndpoint = $"{_url}/register" + }; + + return Results.Ok(metadata); + }); + } + + // JWKS endpoint to expose the public key + app.MapGet("/.well-known/jwks.json", () => + { + var parameters = _rsa.ExportParameters(false); + + // Convert parameters to base64url encoding + var e = WebEncoders.Base64UrlEncode(parameters.Exponent ?? Array.Empty()); + var n = WebEncoders.Base64UrlEncode(parameters.Modulus ?? Array.Empty()); + + var jwks = new JsonWebKeySet + { + Keys = [ + new JsonWebKey + { + KeyType = "RSA", + Use = "sig", + KeyId = _keyId, + Algorithm = "RS256", + Exponent = e, + Modulus = n + } + ] + }; + + return Results.Ok(jwks); + }); + + // Authorize endpoint + app.MapGet("/authorize", ( + [FromQuery] string client_id, + [FromQuery] string? redirect_uri, + [FromQuery] string response_type, + [FromQuery] string code_challenge, + [FromQuery] string code_challenge_method, + [FromQuery] string? scope, + [FromQuery] string? state, + [FromQuery] string? resource) => + { + // Validate client + if (!_clients.TryGetValue(client_id, out var client)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_client", + ErrorDescription = "Client not found" + }); + } + + // Validate redirect_uri + if (string.IsNullOrEmpty(redirect_uri)) + { + if (client.RedirectUris.Count == 1) + { + redirect_uri = client.RedirectUris[0]; + } + else + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "redirect_uri is required when client has multiple registered URIs" + }); + } + } + else if (!client.RedirectUris.Contains(redirect_uri)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "Unregistered redirect_uri" + }); + } + + // Validate response_type + if (response_type != "code") + { + return Results.Redirect($"{redirect_uri}?error=unsupported_response_type&error_description=Only+code+response_type+is+supported&state={state}"); + } + + // Validate code challenge method + if (code_challenge_method != "S256") + { + return Results.Redirect($"{redirect_uri}?error=invalid_request&error_description=Only+S256+code_challenge_method+is+supported&state={state}"); + } + + // Validate resource in accordance with RFC 8707 + if (string.IsNullOrEmpty(resource) || !ValidResources.Contains(resource)) + { + return Results.Redirect($"{redirect_uri}?error=invalid_target&error_description=The+specified+resource+is+not+valid&state={state}"); + } + + // Generate a new authorization code + var code = GenerateRandomToken(); + var requestedScopes = scope?.Split(' ').ToList() ?? []; + + // Store code information for later verification + _authCodes[code] = new AuthorizationCodeInfo + { + ClientId = client_id, + RedirectUri = redirect_uri, + CodeChallenge = code_challenge, + Scope = requestedScopes, + Resource = !string.IsNullOrEmpty(resource) ? new Uri(resource) : null + }; + + // Redirect back to client with the code + var redirectUrl = $"{redirect_uri}?code={code}"; + if (!string.IsNullOrEmpty(state)) + { + redirectUrl += $"&state={Uri.EscapeDataString(state)}"; + } + + return Results.Redirect(redirectUrl); + }); + + // Token endpoint + app.MapPost("/token", async (HttpContext context) => + { + var form = await context.Request.ReadFormAsync(); + + // Authenticate client + var client = AuthenticateClient(context, form); + if (client == null) + { + context.Response.StatusCode = 401; + return Results.Problem( + statusCode: 401, + title: "Unauthorized", + detail: "Invalid client credentials", + type: "https://tools.ietf.org/html/rfc6749#section-5.2"); + } + + // Validate resource in accordance with RFC 8707 + var resource = form["resource"].ToString(); + if (string.IsNullOrEmpty(resource) || !ValidResources.Contains(resource)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_target", + ErrorDescription = "The specified resource is not valid." + }); + } + + var grant_type = form["grant_type"].ToString(); + if (grant_type == "authorization_code") + { + var code = form["code"].ToString(); + var code_verifier = form["code_verifier"].ToString(); + var redirect_uri = form["redirect_uri"].ToString(); + + // Validate code + if (string.IsNullOrEmpty(code) || !_authCodes.TryRemove(code, out var codeInfo)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Invalid authorization code" + }); + } + + // Validate client_id + if (codeInfo.ClientId != client.ClientId) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Authorization code was not issued to this client" + }); + } + + // Validate redirect_uri if provided + if (!string.IsNullOrEmpty(redirect_uri) && redirect_uri != codeInfo.RedirectUri) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Redirect URI mismatch" + }); + } + + // Validate code verifier + if (string.IsNullOrEmpty(code_verifier) || !VerifyCodeChallenge(code_verifier, codeInfo.CodeChallenge)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Code verifier does not match the challenge" + }); + } + + // Generate JWT token response + var response = GenerateJwtTokenResponse(client.ClientId, codeInfo.Scope, codeInfo.Resource); + return Results.Ok(response); + } + else if (grant_type == "refresh_token") + { + var refresh_token = form["refresh_token"].ToString(); + + // Validate refresh token + if (string.IsNullOrEmpty(refresh_token) || !_tokens.TryGetValue(refresh_token, out var tokenInfo) || tokenInfo.ClientId != client.ClientId) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_grant", + ErrorDescription = "Invalid refresh token" + }); + } + + // Generate new token response, keeping the same scopes + var response = GenerateJwtTokenResponse(client.ClientId, tokenInfo.Scopes, tokenInfo.Resource); + + // Remove the old refresh token + if (!string.IsNullOrEmpty(refresh_token)) + { + _tokens.TryRemove(refresh_token, out _); + } + + HasIssuedRefreshToken = true; + return Results.Ok(response); + } + else + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "unsupported_grant_type", + ErrorDescription = "Unsupported grant type" + }); + } + }); + + // Introspection endpoint + app.MapPost("/introspect", async (HttpContext context) => + { + var form = await context.Request.ReadFormAsync(); + var token = form["token"].ToString(); + + if (string.IsNullOrEmpty(token)) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "Token is required" + }); + } + + // Check opaque access tokens + if (_tokens.TryGetValue(token, out var tokenInfo)) + { + if (tokenInfo.ExpiresAt < DateTimeOffset.UtcNow) + { + return Results.Ok(new TokenIntrospectionResponse { Active = false }); + } + + return Results.Ok(new TokenIntrospectionResponse + { + Active = true, + ClientId = tokenInfo.ClientId, + Scope = string.Join(" ", tokenInfo.Scopes), + ExpirationTime = tokenInfo.ExpiresAt.ToUnixTimeSeconds(), + Audience = tokenInfo.Resource?.ToString() + }); + } + + return Results.Ok(new TokenIntrospectionResponse { Active = false }); + }); + + // Dynamic Client Registration endpoint (RFC 7591) + app.MapPost("/register", async (HttpContext context) => + { + using var stream = context.Request.Body; + var registrationRequest = await JsonSerializer.DeserializeAsync( + stream, + OAuthJsonContext.Default.ClientRegistrationRequest, + context.RequestAborted); + + if (registrationRequest is null) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_request", + ErrorDescription = "Invalid registration request" + }); + } + + // Validate redirect URIs are provided + if (registrationRequest.RedirectUris.Count == 0) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_redirect_uri", + ErrorDescription = "At least one redirect URI must be provided" + }); + } + + // Validate redirect URIs + foreach (var redirectUri in registrationRequest.RedirectUris) + { + if (!Uri.TryCreate(redirectUri, UriKind.Absolute, out var uri) || + (uri.Scheme != "http" && uri.Scheme != "https")) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "invalid_redirect_uri", + ErrorDescription = $"Invalid redirect URI: {redirectUri}" + }); + } + } + + // Generate client credentials + var clientId = $"dyn-{Guid.NewGuid():N}"; + var clientSecret = GenerateRandomToken(); + var issuedAt = DateTimeOffset.UtcNow; + + // Store the registered client + _clients[clientId] = new ClientInfo + { + ClientId = clientId, + ClientSecret = clientSecret, + RedirectUris = registrationRequest.RedirectUris, + }; + + var registrationResponse = new ClientRegistrationResponse + { + ClientId = clientId, + ClientSecret = clientSecret, + ClientIdIssuedAt = issuedAt.ToUnixTimeSeconds(), + RedirectUris = registrationRequest.RedirectUris, + GrantTypes = ["authorization_code", "refresh_token"], + ResponseTypes = ["code"], + TokenEndpointAuthMethod = "client_secret_post", + }; + + return Results.Ok(registrationResponse); + }); + + app.MapGet("/", () => "Demo In-Memory OAuth 2.0 Server with JWT Support"); + + Console.WriteLine($"OAuth Authorization Server running at {_url}"); + Console.WriteLine($"OAuth Server Metadata at {_url}/.well-known/oauth-authorization-server"); + Console.WriteLine($"JWT keys available at {_url}/.well-known/jwks.json"); + Console.WriteLine($"Demo Client ID: {clientId}"); + Console.WriteLine($"Demo Client Secret: {clientSecret}"); + + await app.RunAsync(cancellationToken); + } + + /// + /// Authenticates a client based on client credentials in the request. + /// + /// The HTTP context. + /// The form collection containing client credentials. + /// The client info if authentication succeeds, null otherwise. + private ClientInfo? AuthenticateClient(HttpContext context, IFormCollection form) + { + var clientId = form["client_id"].ToString(); + var clientSecret = form["client_secret"].ToString(); + + if (string.IsNullOrEmpty(clientId) || string.IsNullOrEmpty(clientSecret)) + { + return null; + } + + if (_clients.TryGetValue(clientId, out var client) && client.ClientSecret == clientSecret) + { + return client; + } + + return null; + } + + /// + /// Generates a JWT token response. + /// + /// The client ID. + /// The approved scopes. + /// The resource URI. + /// A token response. + private TokenResponse GenerateJwtTokenResponse(string clientId, List scopes, Uri? resource) + { + var expiresIn = TimeSpan.FromHours(1); + var issuedAt = DateTimeOffset.UtcNow; + + // For test-refresh-client, make the first token expired to test refresh functionality. + if (clientId == "test-refresh-client" && !HasIssuedExpiredToken) + { + HasIssuedExpiredToken = true; + expiresIn = TimeSpan.FromHours(-1); + } + + var expiresAt = issuedAt.Add(expiresIn); + var jwtId = Guid.NewGuid().ToString(); + + // Create JWT header and payload + var header = new Dictionary + { + { "alg", "RS256" }, + { "typ", "JWT" }, + { "kid", _keyId } + }; + + var payload = new Dictionary + { + { "iss", _url }, + { "sub", $"user-{clientId}" }, + { "name", $"user-{clientId}" }, + { "aud", resource?.ToString() ?? clientId }, + { "client_id", clientId }, + { "jti", jwtId }, + { "iat", issuedAt.ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture) }, + { "exp", expiresAt.ToUnixTimeSeconds().ToString(CultureInfo.InvariantCulture) }, + { "scope", string.Join(" ", scopes) } + }; + + // Create JWT token + var headerJson = JsonSerializer.Serialize(header, OAuthJsonContext.Default.DictionaryStringString); + var payloadJson = JsonSerializer.Serialize(payload, OAuthJsonContext.Default.DictionaryStringString); + + var headerBase64 = WebEncoders.Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson)); + var payloadBase64 = WebEncoders.Base64UrlEncode(Encoding.UTF8.GetBytes(payloadJson)); + + var dataToSign = $"{headerBase64}.{payloadBase64}"; + var signature = _rsa.SignData(Encoding.UTF8.GetBytes(dataToSign), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + var signatureBase64 = WebEncoders.Base64UrlEncode(signature); + + var jwtToken = $"{headerBase64}.{payloadBase64}.{signatureBase64}"; + + // Generate opaque refresh token + var refreshToken = GenerateRandomToken(); + + // Store token info (for refresh token and introspection) + var tokenInfo = new TokenInfo + { + ClientId = clientId, + Scopes = scopes, + IssuedAt = issuedAt, + ExpiresAt = expiresAt, + Resource = resource, + JwtId = jwtId + }; + + _tokens[refreshToken] = tokenInfo; + + return new TokenResponse + { + AccessToken = jwtToken, + RefreshToken = refreshToken, + TokenType = "Bearer", + ExpiresIn = (int)expiresIn.TotalSeconds, + Scope = string.Join(" ", scopes) + }; + } + + /// + /// Generates a random token for authorization code or refresh token. + /// + /// A Base64Url encoded random token. + public static string GenerateRandomToken() + { + var bytes = new byte[32]; + Random.Shared.NextBytes(bytes); + return WebEncoders.Base64UrlEncode(bytes); + } + + /// + /// Verifies a PKCE code challenge against a code verifier. + /// + /// The code verifier to verify. + /// The code challenge to verify against. + /// True if the code challenge is valid, false otherwise. + public static bool VerifyCodeChallenge(string codeVerifier, string codeChallenge) + { + using var sha256 = SHA256.Create(); + var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); + var computedChallenge = WebEncoders.Base64UrlEncode(challengeBytes); + + return computedChallenge == codeChallenge; + } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/Properties/launchSettings.json b/tests/ModelContextProtocol.TestOAuthServer/Properties/launchSettings.json new file mode 100644 index 000000000..71b2b21fe --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/Properties/launchSettings.json @@ -0,0 +1,14 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "profiles": { + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "https://localhost:7029/", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/tests/ModelContextProtocol.TestOAuthServer/TokenInfo.cs b/tests/ModelContextProtocol.TestOAuthServer/TokenInfo.cs new file mode 100644 index 000000000..159ef34eb --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/TokenInfo.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents token information for OAuth flow. +/// +internal sealed class TokenInfo +{ + /// + /// Gets or sets the client ID associated with this token. + /// + public required string ClientId { get; init; } + + /// + /// Gets or sets the list of scopes approved for this token. + /// + public List Scopes { get; init; } = []; + + /// + /// Gets or sets the issued time of this token. + /// + public required DateTimeOffset IssuedAt { get; init; } + + /// + /// Gets or sets the expiration time of this token. + /// + public required DateTimeOffset ExpiresAt { get; init; } + + /// + /// Gets or sets the optional resource URI this token is for. + /// + public Uri? Resource { get; init; } + + /// + /// Gets or sets the JWT ID for this token. + /// + public string? JwtId { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/TokenIntrospectionResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/TokenIntrospectionResponse.cs new file mode 100644 index 000000000..a27b624ab --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/TokenIntrospectionResponse.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the response from the token introspection endpoint. +/// +internal sealed class TokenIntrospectionResponse +{ + /// + /// Gets or sets a value indicating whether the token is active. + /// + [JsonPropertyName("active")] + public required bool Active { get; init; } + + /// + /// Gets or sets the client ID associated with the token. + /// + [JsonPropertyName("client_id")] + public string? ClientId { get; init; } + + /// + /// Gets or sets the scope of the token. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } + + /// + /// Gets or sets the expiration timestamp of the token (Unix timestamp). + /// + [JsonPropertyName("exp")] + public long? ExpirationTime { get; init; } + + /// + /// Gets or sets the audience of the token. + /// + [JsonPropertyName("aud")] + public string? Audience { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/TokenResponse.cs b/tests/ModelContextProtocol.TestOAuthServer/TokenResponse.cs new file mode 100644 index 000000000..20789feb1 --- /dev/null +++ b/tests/ModelContextProtocol.TestOAuthServer/TokenResponse.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.TestOAuthServer; + +/// +/// Represents the token response for OAuth flow. +/// +internal sealed class TokenResponse +{ + /// + /// Gets or sets the access token. + /// + [JsonPropertyName("access_token")] + public required string AccessToken { get; init; } + + /// + /// Gets or sets the token type. + /// + [JsonPropertyName("token_type")] + public required string TokenType { get; init; } + + /// + /// Gets or sets the token expiration time in seconds. + /// + [JsonPropertyName("expires_in")] + public required int ExpiresIn { get; init; } + + /// + /// Gets or sets the refresh token. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; init; } + + /// + /// Gets or sets the scope approved for this token. + /// + [JsonPropertyName("scope")] + public string? Scope { get; init; } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 97013f64b..0bc4134fa 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -38,7 +38,7 @@ private static async Task Main(string[] args) McpServerOptions options = new() { - Capabilities = new ServerCapabilities() + Capabilities = new ServerCapabilities { Tools = ConfigureTools(), Resources = ConfigureResources(), @@ -75,7 +75,7 @@ private static async Task RunBackgroundLoop(IMcpServer server, CancellationToken if (_minimumLoggingLevel is not null) { var logLevel = loggingLevels[random.Next(loggingLevels.Length)]; - await server.SendMessageAsync(new JsonRpcNotification() + await server.SendMessageAsync(new JsonRpcNotification { Method = NotificationMethods.LoggingMessageNotification, Params = JsonSerializer.SerializeToNode(new LoggingMessageNotificationParams @@ -111,11 +111,11 @@ private static ToolsCapability ConfigureTools() { ListToolsHandler = async (request, cancellationToken) => { - return new ListToolsResult() + return new ListToolsResult { Tools = [ - new Tool() + new Tool { Name = "echo", Description = "Echoes the input back to the client.", @@ -132,7 +132,7 @@ private static ToolsCapability ConfigureTools() } """), }, - new Tool() + new Tool { Name = "echoSessionId", Description = "Echoes the session id back to the client.", @@ -142,7 +142,7 @@ private static ToolsCapability ConfigureTools() } """, McpJsonUtilities.DefaultOptions), }, - new Tool() + new Tool { Name = "sampleLLM", Description = "Samples from an LLM using MCP's sampling feature.", @@ -175,14 +175,14 @@ private static ToolsCapability ConfigureTools() { throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); } - return new CallToolResult() + return new CallToolResult { Content = [new TextContentBlock { Text = $"Echo: {message}" }] }; } else if (request.Params?.Name == "echoSessionId") { - return new CallToolResult() + return new CallToolResult { Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] }; @@ -198,7 +198,7 @@ private static ToolsCapability ConfigureTools() var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), cancellationToken); - return new CallToolResult() + return new CallToolResult { Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] }; @@ -217,27 +217,27 @@ private static PromptsCapability ConfigurePrompts() { ListPromptsHandler = async (request, cancellationToken) => { - return new ListPromptsResult() + return new ListPromptsResult { Prompts = [ - new Prompt() + new Prompt { Name = "simple_prompt", Description = "A prompt without arguments" }, - new Prompt() + new Prompt { Name = "complex_prompt", Description = "A prompt with arguments", Arguments = [ - new PromptArgument() + new PromptArgument { Name = "temperature", Description = "Temperature setting", Required = true }, - new PromptArgument() + new PromptArgument { Name = "style", Description = "Output style", @@ -254,7 +254,7 @@ private static PromptsCapability ConfigurePrompts() List messages = []; if (request.Params?.Name == "simple_prompt") { - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = "This is a simple prompt without arguments." }, @@ -264,20 +264,20 @@ private static PromptsCapability ConfigurePrompts() { string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" }, }); - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.Assistant, Content = new TextContentBlock { Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" }, }); - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.User, - Content = new ImageContentBlock() + Content = new ImageContentBlock { Data = MCP_TINY_IMAGE, MimeType = "image/png" @@ -289,7 +289,7 @@ private static PromptsCapability ConfigurePrompts() throw new McpException($"Unknown prompt: {request.Params?.Name}", McpErrorCode.InvalidParams); } - return new GetPromptResult() + return new GetPromptResult { Messages = messages }; @@ -328,13 +328,13 @@ private static ResourcesCapability ConfigureResources() string uri = $"test://static/resource/{i + 1}"; if (i % 2 == 0) { - resources.Add(new Resource() + resources.Add(new Resource { Uri = uri, Name = $"Resource {i + 1}", MimeType = "text/plain" }); - resourceContents.Add(new TextResourceContents() + resourceContents.Add(new TextResourceContents { Uri = uri, MimeType = "text/plain", @@ -344,13 +344,13 @@ private static ResourcesCapability ConfigureResources() else { var buffer = Encoding.UTF8.GetBytes($"Resource {i + 1}: This is a base64 blob"); - resources.Add(new Resource() + resources.Add(new Resource { Uri = uri, Name = $"Resource {i + 1}", MimeType = "application/octet-stream" }); - resourceContents.Add(new BlobResourceContents() + resourceContents.Add(new BlobResourceContents { Uri = uri, MimeType = "application/octet-stream", @@ -365,10 +365,10 @@ private static ResourcesCapability ConfigureResources() { ListResourceTemplatesHandler = async (request, cancellationToken) => { - return new ListResourceTemplatesResult() + return new ListResourceTemplatesResult { ResourceTemplates = [ - new ResourceTemplate() + new ResourceTemplate { UriTemplate = "test://dynamic/resource/{id}", Name = "Dynamic Resource", @@ -400,7 +400,7 @@ private static ResourcesCapability ConfigureResources() { nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); } - return new ListResourcesResult() + return new ListResourcesResult { NextCursor = nextCursor, Resources = resources.GetRange(startIndex, endIndex - startIndex) @@ -422,10 +422,10 @@ private static ResourcesCapability ConfigureResources() throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } - return new ReadResourceResult() + return new ReadResourceResult { Contents = [ - new TextResourceContents() + new TextResourceContents { Uri = request.Params.Uri, MimeType = "text/plain", @@ -438,7 +438,7 @@ private static ResourcesCapability ConfigureResources() ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); - return new ReadResourceResult() + return new ReadResourceResult { Contents = [contents] }; @@ -523,9 +523,9 @@ private static CompletionsCapability ConfigureCompletions() static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) { - return new CreateMessageRequestParams() + return new CreateMessageRequestParams { - Messages = [new SamplingMessage() + Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 56e98c983..d4abf81f9 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -28,7 +28,7 @@ private static void ConfigureSerilog(ILoggingBuilder loggingBuilder) private static void ConfigureOptions(McpServerOptions options) { - options.Capabilities = new ServerCapabilities() + options.Capabilities = new ServerCapabilities { Tools = new(), Resources = new(), @@ -41,9 +41,9 @@ private static void ConfigureOptions(McpServerOptions options) #region Helped method static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) { - return new CreateMessageRequestParams() + return new CreateMessageRequestParams { - Messages = [new SamplingMessage() + Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = $"Resource {uri} context: {context}" }, @@ -63,13 +63,13 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st string uri = $"test://static/resource/{i + 1}"; if (i % 2 == 0) { - resources.Add(new Resource() + resources.Add(new Resource { Uri = uri, Name = $"Resource {i + 1}", MimeType = "text/plain" }); - resourceContents.Add(new TextResourceContents() + resourceContents.Add(new TextResourceContents { Uri = uri, MimeType = "text/plain", @@ -79,13 +79,13 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st else { var buffer = Encoding.UTF8.GetBytes($"Resource {i + 1}: This is a base64 blob"); - resources.Add(new Resource() + resources.Add(new Resource { Uri = uri, Name = $"Resource {i + 1}", MimeType = "application/octet-stream" }); - resourceContents.Add(new BlobResourceContents() + resourceContents.Add(new BlobResourceContents { Uri = uri, MimeType = "application/octet-stream", @@ -102,11 +102,11 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { ListToolsHandler = async (request, cancellationToken) => { - return new ListToolsResult() + return new ListToolsResult { Tools = [ - new Tool() + new Tool { Name = "echo", Description = "Echoes the input back to the client.", @@ -123,7 +123,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } """, McpJsonUtilities.DefaultOptions), }, - new Tool() + new Tool { Name = "echoSessionId", Description = "Echoes the session id back to the client.", @@ -133,7 +133,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } """, McpJsonUtilities.DefaultOptions), }, - new Tool() + new Tool { Name = "sampleLLM", Description = "Samples from an LLM using MCP's sampling feature.", @@ -169,14 +169,14 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); } - return new CallToolResult() + return new CallToolResult { Content = [new TextContentBlock { Text = $"Echo: {message}" }] }; } else if (request.Params.Name == "echoSessionId") { - return new CallToolResult() + return new CallToolResult { Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] }; @@ -192,7 +192,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st var sampleResult = await request.Server.SampleAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), cancellationToken); - return new CallToolResult() + return new CallToolResult { Content = [new TextContentBlock { Text = $"LLM sampling result: {(sampleResult.Content as TextContentBlock)?.Text}" }] }; @@ -208,10 +208,10 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st ListResourceTemplatesHandler = async (request, cancellationToken) => { - return new ListResourceTemplatesResult() + return new ListResourceTemplatesResult { ResourceTemplates = [ - new ResourceTemplate() + new ResourceTemplate { UriTemplate = "test://dynamic/resource/{id}", Name = "Dynamic Resource", @@ -245,7 +245,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); } - return new ListResourcesResult() + return new ListResourcesResult { NextCursor = nextCursor, Resources = resources.GetRange(startIndex, endIndex - startIndex) @@ -266,10 +266,10 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } - return new ReadResourceResult() + return new ReadResourceResult { Contents = [ - new TextResourceContents() + new TextResourceContents { Uri = request.Params.Uri, MimeType = "text/plain", @@ -282,7 +282,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st ResourceContents? contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); - return new ReadResourceResult() + return new ReadResourceResult { Contents = [contents] }; @@ -292,27 +292,27 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { ListPromptsHandler = async (request, cancellationToken) => { - return new ListPromptsResult() + return new ListPromptsResult { Prompts = [ - new Prompt() + new Prompt { Name = "simple_prompt", Description = "A prompt without arguments" }, - new Prompt() + new Prompt { Name = "complex_prompt", Description = "A prompt with arguments", Arguments = [ - new PromptArgument() + new PromptArgument { Name = "temperature", Description = "Temperature setting", Required = true }, - new PromptArgument() + new PromptArgument { Name = "style", Description = "Output style", @@ -332,7 +332,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st List messages = new(); if (request.Params.Name == "simple_prompt") { - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = "This is a simple prompt without arguments." }, @@ -342,20 +342,20 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" }, }); - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.Assistant, Content = new TextContentBlock { Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" }, }); - messages.Add(new PromptMessage() + messages.Add(new PromptMessage { Role = Role.User, - Content = new ImageContentBlock() + Content = new ImageContentBlock { Data = MCP_TINY_IMAGE, MimeType = "image/png" @@ -367,7 +367,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st throw new McpException($"Unknown prompt: {request.Params.Name}", McpErrorCode.InvalidParams); } - return new GetPromptResult() + return new GetPromptResult { Messages = messages }; diff --git a/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs b/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs new file mode 100644 index 000000000..ec603c63f --- /dev/null +++ b/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs @@ -0,0 +1,28 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Tests; + +public class AIContentExtensionsTests +{ + [Fact] + public void CallToolResult_ToChatMessage_ProducesExpectedAIContent() + { + CallToolResult toolResult = new() { Content = [new TextContentBlock { Text = "This is a test message." }] }; + + Assert.Throws(() => AIContentExtensions.ToChatMessage(null!, "call123")); + Assert.Throws(() => AIContentExtensions.ToChatMessage(toolResult, null!)); + + ChatMessage message = AIContentExtensions.ToChatMessage(toolResult, "call123"); + + Assert.NotNull(message); + Assert.Equal(ChatRole.Tool, message.Role); + + FunctionResultContent frc = Assert.IsType(Assert.Single(message.Contents)); + Assert.Same(toolResult, frc.RawRepresentation); + Assert.Equal("call123", frc.CallId); + JsonElement result = Assert.IsType(frc.Result); + Assert.Contains("This is a test message.", result.ToString()); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs index 7e66fa3db..48c3c370d 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs @@ -15,7 +15,7 @@ public McpClientResourceTemplateTests(ITestOutputHelper outputHelper) : base(out protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) { mcpServerBuilder.WithReadResourceHandler((request, cancellationToken) => - new ValueTask(new ReadResourceResult() + new ValueTask(new ReadResourceResult { Contents = [new TextResourceContents { Text = request.Params?.Uri ?? string.Empty }] })); diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index a8ee21d75..3e4361a57 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -91,7 +91,7 @@ public async Task CallTool_Stdio_EchoServer(string clientId) // assert Assert.NotNull(result); - Assert.False(result.IsError); + Assert.Null(result.IsError); var textContent = Assert.Single(result.Content.OfType()); Assert.Equal("Echo: Hello MCP!", textContent.Text); } @@ -107,7 +107,7 @@ public async Task CallTool_Stdio_EchoSessionId_ReturnsEmpty() // assert Assert.NotNull(result); - Assert.False(result.IsError); + Assert.Null(result.IsError); var textContent = Assert.Single(result.Content.OfType()); Assert.Empty(textContent.Text); } @@ -485,7 +485,7 @@ public async Task CallTool_Stdio_MemoryServer() // assert Assert.NotNull(result); - Assert.False(result.IsError); + Assert.Null(result.IsError); Assert.Single(result.Content, c => c.Type == "text"); await client.DisposeAsync(); diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index 326b235f0..ec1c85107 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -62,13 +62,14 @@ public async ValueTask DisposeAsync() Dispose(); } - protected async Task CreateMcpClientForServer() + protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) { return await McpClientFactory.CreateAsync( new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), LoggerFactory), + clientOptions: clientOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 11ee7eaf0..3fa2ec78b 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -69,7 +69,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer case "FirstCustomPrompt": case "SecondCustomPrompt": case "FinalCustomPrompt": - return new GetPromptResult() + return new GetPromptResult { Messages = [new() { Role = Role.User, Content = new TextContentBlock { Text = $"hello from {request.Params.Name}" } }], }; @@ -100,7 +100,7 @@ public async Task Can_List_And_Call_Registered_Prompts() var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); - var prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsChatMessages)); + var prompt = prompts.First(t => t.Name == "returns_chat_messages"); Assert.Equal("Returns chat messages", prompt.Description); var result = await prompt.GetAsync(new Dictionary() { ["message"] = "hello" }, cancellationToken: TestContext.Current.CancellationToken); @@ -171,7 +171,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() Assert.NotNull(prompts); Assert.NotEmpty(prompts); - McpClientPrompt prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsString)); + McpClientPrompt prompt = prompts.First(t => t.Name == "returns_string"); Assert.Equal("This is a title", prompt.Title); } @@ -204,7 +204,7 @@ public async Task Throws_Exception_Missing_Parameter() await using IMcpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( - nameof(SimplePrompts.ReturnsChatMessages), + "returns_chat_messages", cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal(McpErrorCode.InternalError, e.ErrorCode); @@ -242,7 +242,7 @@ public void Register_Prompts_From_Current_Assembly() sc.AddMcpServer().WithPromptsFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); - Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages)); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == "returns_chat_messages"); } [Fact] @@ -255,10 +255,10 @@ public void Register_Prompts_From_Multiple_Sources() .WithPrompts([McpServerPrompt.Create(() => "42", new() { Name = "Returns42" })]); IServiceProvider services = sc.BuildServiceProvider(); - Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages)); - Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ThrowsException)); - Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsString)); - Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(MorePrompts.AnotherPrompt)); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == "returns_chat_messages"); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == "throws_exception"); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == "returns_string"); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == "another_prompt"); Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == "Returns42"); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index 7cee174da..ed930b174 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -99,7 +99,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer case "test://Resource3": case "test://ResourceTemplate1": case "test://ResourceTemplate2": - return new ReadResourceResult() + return new ReadResourceResult { Contents = [new TextResourceContents { Text = request.Params?.Uri ?? "(null)" }] }; @@ -129,7 +129,7 @@ public async Task Can_List_And_Call_Registered_Resources() var resources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); Assert.Equal(5, resources.Count); - var resource = resources.First(t => t.Name == nameof(SimpleResources.SomeNeatDirectResource)); + var resource = resources.First(t => t.Name == "some_neat_direct_resource"); Assert.Equal("Some neat direct resource", resource.Description); var result = await resource.ReadAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -146,7 +146,7 @@ public async Task Can_List_And_Call_Registered_ResourceTemplates() var resources = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken); Assert.Equal(3, resources.Count); - var resource = resources.First(t => t.Name == nameof(SimpleResources.SomeNeatTemplatedResource)); + var resource = resources.First(t => t.Name == "some_neat_templated_resource"); Assert.Equal("Some neat resource with parameters", resource.Description); var result = await resource.ReadAsync(new Dictionary() { ["name"] = "hello" }, cancellationToken: TestContext.Current.CancellationToken); @@ -204,13 +204,13 @@ public async Task TitleAttributeProperty_PropagatedToTitle() var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(resources); Assert.NotEmpty(resources); - McpClientResource resource = resources.First(t => t.Name == nameof(SimpleResources.SomeNeatDirectResource)); + McpClientResource resource = resources.First(t => t.Name == "some_neat_direct_resource"); Assert.Equal("This is a title", resource.Title); var resourceTemplates = await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(resourceTemplates); Assert.NotEmpty(resourceTemplates); - McpClientResourceTemplate resourceTemplate = resourceTemplates.First(t => t.Name == nameof(SimpleResources.SomeNeatTemplatedResource)); + McpClientResourceTemplate resourceTemplate = resourceTemplates.First(t => t.Name == "some_neat_templated_resource"); Assert.Equal("This is another title", resourceTemplate.Title); } @@ -220,7 +220,7 @@ public async Task Throws_When_Resource_Fails() await using IMcpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( - $"resource://{nameof(SimpleResources.ThrowsException)}", + $"resource://mcp/{nameof(SimpleResources.ThrowsException)}", cancellationToken: TestContext.Current.CancellationToken)); } @@ -230,7 +230,7 @@ public async Task Throws_Exception_On_Unknown_Resource() await using IMcpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( - "test://NotRegisteredResource", + "test:///NotRegisteredResource", cancellationToken: TestContext.Current.CancellationToken)); Assert.Contains("Resource not found", e.Message); @@ -268,8 +268,8 @@ public void Register_Resources_From_Current_Assembly() sc.AddMcpServer().WithResourcesFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); - Assert.Contains(services.GetServices(), t => t.ProtocolResource?.Uri == $"resource://{nameof(SimpleResources.SomeNeatDirectResource)}"); - Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://{nameof(SimpleResources.SomeNeatTemplatedResource)}{{?name}}"); + Assert.Contains(services.GetServices(), t => t.ProtocolResource?.Uri == $"resource://mcp/some_neat_direct_resource"); + Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/some_neat_templated_resource{{?name}}"); } [Fact] @@ -279,13 +279,13 @@ public void Register_Resources_From_Multiple_Sources() sc.AddMcpServer() .WithResources() .WithResources() - .WithResources([McpServerResource.Create(() => "42", new() { UriTemplate = "myResources://Returns42/{something}" })]); + .WithResources([McpServerResource.Create(() => "42", new() { UriTemplate = "myResources:///returns42/{something}" })]); IServiceProvider services = sc.BuildServiceProvider(); - Assert.Contains(services.GetServices(), t => t.ProtocolResource?.Uri == $"resource://{nameof(SimpleResources.SomeNeatDirectResource)}"); - Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://{nameof(SimpleResources.SomeNeatTemplatedResource)}{{?name}}"); - Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://{nameof(MoreResources.AnotherNeatDirectResource)}"); - Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate.UriTemplate == "myResources://Returns42/{something}"); + Assert.Contains(services.GetServices(), t => t.ProtocolResource?.Uri == $"resource://mcp/some_neat_direct_resource"); + Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/some_neat_templated_resource{{?name}}"); + Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/another_neat_direct_resource"); + Assert.Contains(services.GetServices(), t => t.ProtocolResourceTemplate.UriTemplate == "myResources:///returns42/{something}"); } [McpServerResourceType] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 997407120..38c688cce 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -95,7 +95,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer case "FirstCustomTool": case "SecondCustomTool": case "FinalCustomTool": - return new CallToolResult() + return new CallToolResult { Content = [new TextContentBlock { Text = $"{request.Params.Name}Result" }], }; @@ -126,8 +126,7 @@ public async Task Can_List_Registered_Tools() var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); - McpClientTool echoTool = tools.First(t => t.Name == "Echo"); - Assert.Equal("Echo", echoTool.Name); + McpClientTool echoTool = tools.First(t => t.Name == "echo"); Assert.Equal("Echoes the input back to the client.", echoTool.Description); Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString()); Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind); @@ -165,8 +164,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); - McpClientTool echoTool = tools.First(t => t.Name == "Echo"); - Assert.Equal("Echo", echoTool.Name); + McpClientTool echoTool = tools.First(t => t.Name == "echo"); Assert.Equal("Echoes the input back to the client.", echoTool.Description); Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString()); Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind); @@ -231,7 +229,7 @@ public async Task Can_Call_Registered_Tool() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "Echo", + "echo", new Dictionary() { ["message"] = "Peter" }, cancellationToken: TestContext.Current.CancellationToken); @@ -250,7 +248,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "EchoArray", + "echo_array", new Dictionary() { ["message"] = "Peter" }, cancellationToken: TestContext.Current.CancellationToken); @@ -274,7 +272,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "ReturnNull", + "return_null", cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); @@ -288,7 +286,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "ReturnJson", + "return_json", cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); @@ -305,7 +303,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "ReturnInteger", + "return_integer", cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result.Content); @@ -320,7 +318,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "EchoComplex", + "echo_complex", new Dictionary() { ["complex"] = JsonDocument.Parse("""{"Name": "Peter", "Age": 25}""").RootElement }, cancellationToken: TestContext.Current.CancellationToken); @@ -340,7 +338,7 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() for (int i = 0; i < 2; i++) { var result = await client.CallToolAsync( - nameof(EchoTool.GetCtorParameter), + "get_ctor_parameter", cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); @@ -366,7 +364,7 @@ public async Task Returns_IsError_Content_When_Tool_Fails() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "ThrowException", + "throw_exception", cancellationToken: TestContext.Current.CancellationToken); Assert.True(result.IsError); @@ -393,7 +391,7 @@ public async Task Returns_IsError_Missing_Parameter() await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( - "Echo", + "echo", cancellationToken: TestContext.Current.CancellationToken); Assert.True(result.IsError); @@ -436,7 +434,7 @@ public void Register_Tools_From_Current_Assembly() sc.AddMcpServer().WithToolsFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); - Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "Echo"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "echo"); } [Theory] @@ -452,7 +450,7 @@ public void WithTools_Parameters_Satisfiable_From_DI(bool parameterInServices) sc.AddMcpServer().WithTools([typeof(EchoTool)], BuilderToolsJsonContext.Default.Options); IServiceProvider services = sc.BuildServiceProvider(); - McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "EchoComplex"); + McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "echo_complex"); if (parameterInServices) { Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, AIJsonUtilities.DefaultOptions)); @@ -495,7 +493,7 @@ public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime sc.AddMcpServer().WithToolsFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); - McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "EchoComplex"); + McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "echo_complex"); if (lifetime is not null) { Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, AIJsonUtilities.DefaultOptions)); @@ -516,8 +514,7 @@ public async Task Recognizes_Parameter_Types() Assert.NotNull(tools); Assert.NotEmpty(tools); - var tool = tools.First(t => t.Name == "TestTool"); - Assert.Equal("TestTool", tool.Name); + var tool = tools.First(t => t.Name == "test_tool"); Assert.Empty(tool.Description!); Assert.Equal("object", tool.JsonSchema.GetProperty("type").GetString()); @@ -543,9 +540,9 @@ public void Register_Tools_From_Multiple_Sources() Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "double_echo"); Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "DifferentName"); - Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodB"); - Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodC"); - Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodD"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "method_b"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "method_c"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "method_d"); Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "Returns42"); } @@ -591,7 +588,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() Assert.NotNull(tools); Assert.NotEmpty(tools); - McpClientTool tool = tools.First(t => t.Name == nameof(EchoTool.EchoComplex)); + McpClientTool tool = tools.First(t => t.Name == "echo_complex"); Assert.Equal("This is a title", tool.Title); Assert.Equal("This is a title", tool.ProtocolTool.Title); @@ -607,7 +604,7 @@ public async Task HandlesIProgressParameter() Assert.NotNull(tools); Assert.NotEmpty(tools); - McpClientTool progressTool = tools.First(t => t.Name == nameof(EchoTool.SendsProgressNotifications)); + McpClientTool progressTool = tools.First(t => t.Name == "sends_progress_notifications"); TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); int remainingNotifications = 10; @@ -660,7 +657,7 @@ public async Task CancellationNotificationsPropagateToToolTokens() var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); Assert.NotEmpty(tools); - McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation)); + McpClientTool cancelableTool = tools.First(t => t.Name == "infinite_cancelable_operation"); var requestId = new RequestId(Guid.NewGuid().ToString()); var invokeTask = client.SendRequestAsync( @@ -671,7 +668,7 @@ public async Task CancellationNotificationsPropagateToToolTokens() await client.SendNotificationAsync( NotificationMethods.CancelledNotification, - parameters: new CancelledNotificationParams() + parameters: new CancelledNotificationParams { RequestId = requestId, }, diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs index 8f90d2568..b940c1c7c 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -25,7 +25,7 @@ public async Task InjectScopedServiceAsArgument() await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(McpServerScopedTestsJsonContext.Default.Options, TestContext.Current.CancellationToken); - var tool = tools.First(t => t.Name == nameof(EchoTool.EchoComplex)); + var tool = tools.First(t => t.Name == "echo_complex"); Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.JsonSchema, McpJsonUtilities.DefaultOptions)); int startingConstructed = ComplexObject.Constructed; diff --git a/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs index 4def27938..80c6b1ed9 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs @@ -54,7 +54,7 @@ public async Task CancellationPropagation_RequestingCancellationCancelsPendingRe await using var client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - var waitTool = tools.First(t => t.Name == nameof(WaitForCancellation)); + var waitTool = tools.First(t => t.Name == "wait_for_cancellation"); CancellationTokenSource cts = new(); var waitTask = waitTool.InvokeAsync(cancellationToken: cts.Token); diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs new file mode 100644 index 000000000..f44743916 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs @@ -0,0 +1,147 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Configuration; + +public partial class ElicitationTests : ClientServerTestBase +{ + public ElicitationTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder.WithCallToolHandler(async (request, cancellationToken) => + { + Assert.Equal("TestElicitation", request.Params?.Name); + + var result = await request.Server.ElicitAsync( + new() + { + Message = "Please provide more information.", + RequestedSchema = new() + { + Properties = new Dictionary() + { + ["prop1"] = new ElicitRequestParams.StringSchema + { + Title = "title1", + MinLength = 1, + MaxLength = 100, + }, + ["prop2"] = new ElicitRequestParams.NumberSchema + { + Description = "description2", + Minimum = 0, + Maximum = 1000, + }, + ["prop3"] = new ElicitRequestParams.BooleanSchema + { + Title = "title3", + Description = "description4", + Default = true, + }, + ["prop4"] = new ElicitRequestParams.EnumSchema + { + Enum = ["option1", "option2", "option3"], + EnumNames = ["Name1", "Name2", "Name3"], + }, + }, + }, + }, + CancellationToken.None); + + Assert.Equal("accept", result.Action); + + return new CallToolResult + { + Content = [new TextContentBlock { Text = "success" }], + }; + }); + } + + [Fact] + public async Task Can_Elicit_Information() + { + await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + { + Capabilities = new() + { + Elicitation = new() + { + ElicitationHandler = async (request, cancellationtoken) => + { + Assert.NotNull(request); + Assert.Equal("Please provide more information.", request.Message); + Assert.Equal(4, request.RequestedSchema.Properties.Count); + + foreach (var entry in request.RequestedSchema.Properties) + { + switch (entry.Key) + { + case "prop1": + var primitiveString = Assert.IsType(entry.Value); + Assert.Equal("title1", primitiveString.Title); + Assert.Equal(1, primitiveString.MinLength); + Assert.Equal(100, primitiveString.MaxLength); + break; + + case "prop2": + var primitiveNumber = Assert.IsType(entry.Value); + Assert.Equal("description2", primitiveNumber.Description); + Assert.Equal(0, primitiveNumber.Minimum); + Assert.Equal(1000, primitiveNumber.Maximum); + break; + + case "prop3": + var primitiveBool = Assert.IsType(entry.Value); + Assert.Equal("title3", primitiveBool.Title); + Assert.Equal("description4", primitiveBool.Description); + Assert.True(primitiveBool.Default); + break; + + case "prop4": + var primitiveEnum = Assert.IsType(entry.Value); + Assert.Equal(["option1", "option2", "option3"], primitiveEnum.Enum); + Assert.Equal(["Name1", "Name2", "Name3"], primitiveEnum.EnumNames); + break; + + default: + Assert.Fail($"Unknown property: {entry.Key}"); + break; + } + } + + return new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["prop1"] = (JsonElement)JsonSerializer.Deserialize(""" + "string result" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop2"] = (JsonElement)JsonSerializer.Deserialize(""" + 42 + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop3"] = (JsonElement)JsonSerializer.Deserialize(""" + true + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop4"] = (JsonElement)JsonSerializer.Deserialize(""" + "option2" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + }, + }; + }, + }, + }, + }); + + var result = await client.CallToolAsync("TestElicitation", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("success", (result.Content[0] as TextContentBlock)?.Text); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index ca1bfe97b..90998e24b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -1,9 +1,11 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Primitives; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; using System.ComponentModel; +using System.Diagnostics; using System.Reflection; using System.Text.Json; using System.Text.Json.Nodes; @@ -44,6 +46,58 @@ public async Task SupportsIMcpServer() Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } + [Fact] + public async Task SupportsCtorInjection() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + Mock mockServer = new(); + mockServer.SetupGet(s => s.Services).Returns(services); + + MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestPrompt)); + Assert.NotNull(testMethod); + McpServerPrompt prompt = McpServerPrompt.Create(testMethod, r => + { + Assert.NotNull(r.Services); + return ActivatorUtilities.CreateInstance(r.Services, typeof(HasCtorWithSpecialParameters)); + }, new() { Services = services }); + + var result = await prompt.GetAsync( + new RequestContext(mockServer.Object), + TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Messages); + Assert.Single(result.Messages); + Assert.Equal("True True True True", Assert.IsType(result.Messages[0].Content).Text); + } + + private sealed class HasCtorWithSpecialParameters + { + private readonly MyService _ms; + private readonly IMcpServer _server; + private readonly RequestContext _request; + private readonly IProgress _progress; + + public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + { + Assert.NotNull(ms); + Assert.NotNull(server); + Assert.NotNull(request); + Assert.NotNull(progress); + + _ms = ms; + _server = server; + _request = request; + _progress = progress; + } + + public string TestPrompt() => $"{_ms is not null} {_server is not null} {_request is not null} {_progress is not null}"; + } + [Fact] public async Task SupportsServiceFromDI() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 3e25d4e92..fb0772d04 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -122,159 +122,188 @@ public void Create_InvalidArgs_Throws() public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() { const string Name = "Hello"; + McpServerResource t; ReadResourceResult? result; IMcpServer server = new Mock().Object; t = McpServerResource.Create(() => "42", new() { Name = Name }); - Assert.Equal($"resource://{Name}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}" } }, + new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((IMcpServer server) => "42", new() { Name = Name }); - Assert.Equal($"resource://{Name}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}" } }, + new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((string arg1) => arg1, new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?arg1}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?arg1}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?arg1=wOrLd" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wOrLd" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("wOrLd", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((string arg1, string? arg2 = null) => arg1 + arg2, new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?arg1,arg2}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?arg1,arg2}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?arg1=wo&arg2=rld" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wo&arg2=rld" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("world", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((object a1, bool a2, char a3, byte a4, sbyte a5) => a1.ToString() + a2 + a3 + a4 + a5, new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("hiTrues1234", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((ushort a1, short a2, uint a3, int a4, ulong a5) => (a1 + a2 + a3 + a4 + (long)a5).ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((long a1, float a2, double a3, decimal a4, TimeSpan a5) => a5.ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((DateTime a1, DateTimeOffset a2, Uri a3, Guid a4, Version a5) => a4.ToString("N") + a5, new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e411.2.3.4", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((Half a2, Int128 a3, UInt128 a4, IntPtr a5) => (a3 + (Int128)a4 + a5).ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((UIntPtr a1, DateOnly a2, TimeOnly a3) => a1.ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((bool? a2, char? a3, byte? a4, sbyte? a5) => a2?.ToString() + a3 + a4 + a5, new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("Trues1234", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((ushort? a1, short? a2, uint? a3, int? a4, ulong? a5) => (a1 + a2 + a3 + a4 + (long?)a5).ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((long? a1, float? a2, double? a3, decimal? a4, TimeSpan? a5) => a5?.ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((DateTime? a1, DateTimeOffset? a2, Guid? a4) => a4?.ToString("N"), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a4}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a4}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e41", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((Half? a2, Int128? a3, UInt128? a4, IntPtr? a5) => (a3 + (Int128?)a4 + a5).ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); t = McpServerResource.Create((UIntPtr? a1, DateOnly? a2, TimeOnly? a3) => a1?.ToString(), new() { Name = Name }); - Assert.Equal($"resource://{Name}{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://{Name}?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); } [Theory] - [InlineData("resource://Hello?arg1=42&arg2=84")] - [InlineData("resource://Hello?arg1=42&arg2=84&arg3=123")] - [InlineData("resource://Hello#fragment")] + [InlineData("resource://mcp/Hello?arg1=42&arg2=84")] + [InlineData("resource://mcp/Hello?arg1=42&arg2=84&arg3=123")] + [InlineData("resource://mcp/Hello#fragment")] public async Task UriTemplate_NonMatchingUri_ReturnsNull(string uri) { McpServerResource t = McpServerResource.Create((string arg1) => arg1, new() { Name = "Hello" }); - Assert.Equal("resource://Hello{?arg1}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal("resource://mcp/Hello{?arg1}", t.ProtocolResourceTemplate.UriTemplate); Assert.Null(await t.ReadAsync( new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } [Theory] - [InlineData("resource://Hello?arg1=test")] - [InlineData("resource://Hello?arg2=test")] + [InlineData("resource://MyCoolResource", "resource://mycoolresource")] + [InlineData("resource://MyCoolResource{?arg1}", "resource://mycoolresource?arg1=42")] + public async Task UriTemplate_IsHostCaseInsensitive(string actualUri, string queriedUri) + { + McpServerResource t = McpServerResource.Create(() => "resource", new() { UriTemplate = actualUri }); + Assert.NotNull(await t.ReadAsync( + new RequestContext(new Mock().Object) { Params = new() { Uri = queriedUri } }, + TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task ResourceCollection_UsesCaseInsensitiveHostLookup() + { + McpServerResource t1 = McpServerResource.Create(() => "resource", new() { UriTemplate = "resource://MyCoolResource" }); + McpServerResource t2 = McpServerResource.Create(() => "resource", new() { UriTemplate = "resource://MyCoolResource2" }); + McpServerResourceCollection collection = new() { t1, t2 }; + Assert.True(collection.TryGetPrimitive("resource://mycoolresource", out McpServerResource? result)); + Assert.Same(t1, result); + } + + [Fact] + public void MimeType_DefaultsToOctetStream() + { + McpServerResource t = McpServerResource.Create(() => "resource", new() { Name = "My Cool Resource" }); + Assert.Equal("application/octet-stream", t.ProtocolResourceTemplate.MimeType); + } + + [Theory] + [InlineData("resource://mcp/Hello?arg1=test")] + [InlineData("resource://mcp/Hello?arg2=test")] public async Task UriTemplate_MissingParameter_Throws(string uri) { McpServerResource t = McpServerResource.Create((string arg1, int arg2) => arg1, new() { Name = "Hello" }); - Assert.Equal("resource://Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal("resource://mcp/Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); await Assert.ThrowsAsync(async () => await t.ReadAsync( new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); @@ -284,30 +313,30 @@ await Assert.ThrowsAsync(async () => await t.ReadAsync( public async Task UriTemplate_MissingOptionalParameter_Succeeds() { McpServerResource t = McpServerResource.Create((string? arg1 = null, int? arg2 = null) => arg1 + arg2, new() { Name = "Hello" }); - Assert.Equal("resource://Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); + Assert.Equal("resource://mcp/Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); ReadResourceResult? result; result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://Hello" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://Hello?arg1=first" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://Hello?arg2=42" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://Hello?arg1=first&arg2=42" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first42", ((TextResourceContents)result.Contents[0]).Text); @@ -325,12 +354,65 @@ public async Task SupportsIMcpServer() }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); } + [Fact] + public async Task SupportsCtorInjection() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + Mock mockServer = new(); + mockServer.SetupGet(s => s.Services).Returns(services); + + MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestResource)); + Assert.NotNull(testMethod); + McpServerResource tool = McpServerResource.Create(testMethod, r => + { + Assert.NotNull(r.Services); + return ActivatorUtilities.CreateInstance(r.Services, typeof(HasCtorWithSpecialParameters)); + }, new() { Services = services }); + + var result = await tool.ReadAsync( + new RequestContext(mockServer.Object) { Params = new() { Uri = "https://something/" } }, + TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Contents); + Assert.Single(result.Contents); + Assert.Equal("True True True True", Assert.IsType(result.Contents[0]).Text); + } + + private sealed class HasCtorWithSpecialParameters + { + private readonly MyService _ms; + private readonly IMcpServer _server; + private readonly RequestContext _request; + private readonly IProgress _progress; + + public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + { + Assert.NotNull(ms); + Assert.NotNull(server); + Assert.NotNull(request); + Assert.NotNull(progress); + + _ms = ms; + _server = server; + _request = request; + _progress = progress; + } + + [McpServerResource(UriTemplate = "https://something/")] + public string TestResource() => $"{_ms is not null} {_server is not null} {_request is not null} {_progress is not null}"; + } + [Theory] [InlineData(ServiceLifetime.Singleton)] [InlineData(ServiceLifetime.Scoped)] @@ -376,11 +458,11 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Mock mockServer = new(); await Assert.ThrowsAnyAsync(async () => await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken)); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Services = services, Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Services = services, Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -402,7 +484,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services, Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -436,7 +518,7 @@ public async Task CanReturnReadResult() return new ReadResourceResult { Contents = new List { new TextResourceContents { Text = "hello" } } }; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -453,7 +535,7 @@ public async Task CanReturnResourceContents() return new TextResourceContents { Text = "hello" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -474,7 +556,7 @@ public async Task CanReturnCollectionOfResourceContents() ]; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -492,7 +574,7 @@ public async Task CanReturnString() return "42"; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -509,7 +591,7 @@ public async Task CanReturnCollectionOfStrings() return new List { "42", "43" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -527,7 +609,7 @@ public async Task CanReturnDataContent() return new DataContent(new byte[] { 0, 1, 2 }, "application/octet-stream"); }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -542,14 +624,14 @@ public async Task CanReturnCollectionOfAIContent() McpServerResource resource = McpServerResource.Create((IMcpServer server) => { Assert.Same(mockServer.Object, server); - return new List() + return new List { new TextContent("hello!"), new DataContent(new byte[] { 4, 5, 6 }, "application/json"), }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://Test" } }, + new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 9e7eca9f6..260b9bdd7 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -586,7 +586,7 @@ public async Task Can_SendMessage_Before_RunAsync() await using var transport = new TestServerTransport(); await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); - var logNotification = new JsonRpcNotification() + var logNotification = new JsonRpcNotification { Method = NotificationMethods.LoggingMessageNotification }; diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 742133413..0f67f2a58 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -50,6 +50,58 @@ public async Task SupportsIMcpServer() Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } + [Fact] + public async Task SupportsCtorInjection() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + Mock mockServer = new(); + mockServer.SetupGet(s => s.Services).Returns(services); + + MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestTool)); + Assert.NotNull(testMethod); + McpServerTool tool = McpServerTool.Create(testMethod, r => + { + Assert.NotNull(r.Services); + return ActivatorUtilities.CreateInstance(r.Services, typeof(HasCtorWithSpecialParameters)); + }, new() { Services = services }); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object), + TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.Single(result.Content); + Assert.Equal("True True True True", Assert.IsType(result.Content[0]).Text); + } + + private sealed class HasCtorWithSpecialParameters + { + private readonly MyService _ms; + private readonly IMcpServer _server; + private readonly RequestContext _request; + private readonly IProgress _progress; + + public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + { + Assert.NotNull(ms); + Assert.NotNull(server); + Assert.NotNull(request); + Assert.NotNull(progress); + + _ms = ms; + _server = server; + _request = request; + _progress = progress; + } + + public string TestTool() => $"{_ms is not null} {_server is not null} {_request is not null} {_progress is not null}"; + } + [Theory] [InlineData(ServiceLifetime.Singleton)] [InlineData(ServiceLifetime.Scoped)] diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index cc6b4e0a1..3ff504304 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -35,7 +35,7 @@ public void Constructor_Throws_For_Null_Options() [Fact] public void Constructor_Throws_For_Null_HttpClient() { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, null!, LoggerFactory)); + var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, httpClient: null!, LoggerFactory)); Assert.Equal("httpClient", exception.ParamName); }