diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 57cba171fb..9fd6c03c72 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,2 +1,5 @@ # Applied 120 line-length rule to all files: https://github.com/modelcontextprotocol/python-sdk/pull/856 543961968c0634e93d919d509cce23a1d6a56c21 + +# Added 100% code coverage baseline with pragma comments: https://github.com/modelcontextprotocol/python-sdk/pull/1553 +89e9c43acf7e23cf766357d776ec1ce63ac2c58e diff --git a/.gitattribute b/.gitattribute index d9d7d2c7ee..0ab3744850 100644 --- a/.gitattribute +++ b/.gitattribute @@ -1,2 +1,2 @@ # Generated -uv.lock linguist-generated=true \ No newline at end of file +uv.lock linguist-generated=true diff --git a/.github/workflows/comment-on-release.yml b/.github/workflows/comment-on-release.yml new file mode 100644 index 0000000000..f8b1751e53 --- /dev/null +++ b/.github/workflows/comment-on-release.yml @@ -0,0 +1,149 @@ +name: Comment on PRs in Release + +on: + release: + types: [published] + +permissions: + pull-requests: write + contents: read + +jobs: + comment-on-prs: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Get previous release + id: previous_release + uses: actions/github-script@v7 + with: + script: | + const currentTag = '${{ github.event.release.tag_name }}'; + + // Get all releases + const { data: releases } = await github.rest.repos.listReleases({ + owner: context.repo.owner, + repo: context.repo.repo, + per_page: 100 + }); + + // Find current release index + const currentIndex = releases.findIndex(r => r.tag_name === currentTag); + + if (currentIndex === -1) { + console.log('Current release not found in list'); + return null; + } + + // Get previous release (next in the list since they're sorted by date desc) + const previousRelease = releases[currentIndex + 1]; + + if (!previousRelease) { + console.log('No previous release found, this might be the first release'); + return null; + } + + console.log(`Found previous release: ${previousRelease.tag_name}`); + + return previousRelease.tag_name; + + - name: Get merged PRs between releases + id: get_prs + uses: actions/github-script@v7 + with: + script: | + const currentTag = '${{ github.event.release.tag_name }}'; + const previousTag = ${{ steps.previous_release.outputs.result }}; + + if (!previousTag) { + console.log('No previous release found, skipping'); + return []; + } + + console.log(`Finding PRs between ${previousTag} and ${currentTag}`); + + // Get commits between previous and current release + const comparison = await github.rest.repos.compareCommits({ + owner: context.repo.owner, + repo: context.repo.repo, + base: previousTag, + head: currentTag + }); + + const commits = comparison.data.commits; + console.log(`Found ${commits.length} commits`); + + // Get PRs associated with each commit using GitHub API + const prNumbers = new Set(); + + for (const commit of commits) { + try { + const { data: prs } = await github.rest.repos.listPullRequestsAssociatedWithCommit({ + owner: context.repo.owner, + repo: context.repo.repo, + commit_sha: commit.sha + }); + + for (const pr of prs) { + if (pr.merged_at) { + prNumbers.add(pr.number); + console.log(`Found merged PR: #${pr.number}`); + } + } + } catch (error) { + console.log(`Failed to get PRs for commit ${commit.sha}: ${error.message}`); + } + } + + console.log(`Found ${prNumbers.size} merged PRs`); + return Array.from(prNumbers); + + - name: Comment on PRs + uses: actions/github-script@v7 + with: + script: | + const prNumbers = ${{ steps.get_prs.outputs.result }}; + const releaseTag = '${{ github.event.release.tag_name }}'; + const releaseUrl = '${{ github.event.release.html_url }}'; + + const comment = `This pull request is included in [${releaseTag}](${releaseUrl})`; + + let commentedCount = 0; + + for (const prNumber of prNumbers) { + try { + // Check if we've already commented on this PR for this release + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + per_page: 100 + }); + + const alreadyCommented = comments.some(c => + c.user.type === 'Bot' && c.body.includes(releaseTag) + ); + + if (alreadyCommented) { + console.log(`Skipping PR #${prNumber} - already commented for ${releaseTag}`); + continue; + } + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body: comment + }); + commentedCount++; + console.log(`Successfully commented on PR #${prNumber}`); + } catch (error) { + console.error(`Failed to comment on PR #${prNumber}:`, error.message); + } + } + + console.log(`Commented on ${commentedCount} of ${prNumbers.length} PRs`); diff --git a/.github/workflows/main-checks.yml b/.github/workflows/main-checks.yml index 6f38043cdd..e2b2a97a14 100644 --- a/.github/workflows/main-checks.yml +++ b/.github/workflows/main-checks.yml @@ -5,6 +5,7 @@ on: branches: - main - "v*.*.*" + - "v1.x" tags: - "v*.*.*" diff --git a/.github/workflows/publish-docs-manually.yml b/.github/workflows/publish-docs-manually.yml index f23aaa92fe..befe44d31c 100644 --- a/.github/workflows/publish-docs-manually.yml +++ b/.github/workflows/publish-docs-manually.yml @@ -19,7 +19,7 @@ jobs: uses: astral-sh/setup-uv@v3 with: enable-cache: true - version: 0.7.2 + version: 0.9.5 - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - uses: actions/cache@v4 diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 0d9eb2de0f..59ede84172 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -16,7 +16,7 @@ jobs: uses: astral-sh/setup-uv@v3 with: enable-cache: true - version: 0.7.2 + version: 0.9.5 - name: Set up Python 3.12 run: uv python install 3.12 @@ -68,7 +68,7 @@ jobs: uses: astral-sh/setup-uv@v3 with: enable-cache: true - version: 0.7.2 + version: 0.9.5 - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - uses: actions/cache@v4 diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index 7d6ec5d610..531487db5a 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -13,56 +13,63 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: astral-sh/setup-uv@v5 + - uses: astral-sh/setup-uv@v7 with: enable-cache: true - version: 0.7.2 - + version: 0.9.5 - name: Install dependencies run: uv sync --frozen --all-extras --python 3.10 - - uses: pre-commit/action@v3.0.0 + - uses: pre-commit/action@v3.0.1 with: extra_args: --all-files --verbose env: SKIP: no-commit-to-branch test: + name: test (${{ matrix.python-version }}, ${{ matrix.dep-resolution.name }}, ${{ matrix.os }}) runs-on: ${{ matrix.os }} timeout-minutes: 10 continue-on-error: true strategy: matrix: python-version: ["3.10", "3.11", "3.12", "3.13"] - dep-resolution: ["lowest-direct", "highest"] + dep-resolution: + - name: lowest-direct + install-flags: "--upgrade --resolution lowest-direct" + - name: highest + install-flags: "--upgrade --resolution highest" os: [ubuntu-latest, windows-latest] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install uv - uses: astral-sh/setup-uv@v3 + uses: astral-sh/setup-uv@v7 with: enable-cache: true - version: 0.7.2 + version: 0.9.5 - name: Install the project - run: uv sync --frozen --all-extras --python ${{ matrix.python-version }} --resolution ${{ matrix.dep-resolution }} + run: uv sync ${{ matrix.dep-resolution.install-flags }} --all-extras --python ${{ matrix.python-version }} - - name: Run pytest - run: uv run --frozen --no-sync pytest + - name: Run pytest with coverage + run: | + uv run --frozen --no-sync coverage run -m pytest + uv run --frozen --no-sync coverage combine + uv run --frozen --no-sync coverage report readme-snippets: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: astral-sh/setup-uv@v5 + - uses: astral-sh/setup-uv@v7 with: enable-cache: true - version: 0.7.2 + version: 0.9.5 - name: Install dependencies run: uv sync --frozen --all-extras --python 3.10 diff --git a/.gitignore b/.gitignore index 429a0375ae..2478cac4b3 100644 --- a/.gitignore +++ b/.gitignore @@ -89,7 +89,7 @@ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: -# .python-version +.python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28da65c608..c06b9028da 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,11 @@ fail_fast: true repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: end-of-file-fixer + - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.1.0 hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c18937f5b3..dd60f39ce5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,9 +23,14 @@ uv tool install pre-commit --with pre-commit-uv --force-reinstall ## Development Workflow 1. Choose the correct branch for your changes: - - For bug fixes to a released version: use the latest release branch (e.g. v1.1.x for 1.1.3) - - For new features: use the main branch (which will become the next minor/major version) - - If unsure, ask in an issue first + + | Change Type | Target Branch | Example | + |-------------|---------------|---------| + | New features, breaking changes | `main` | New APIs, refactors | + | Security fixes for v1 | `v1.x` | Critical patches | + | Bug fixes for v1 | `v1.x` | Non-breaking fixes | + + > **Note:** `main` is the v2 development branch. Breaking changes are welcome on `main`. The `v1.x` branch receives only security and critical bug fixes. 2. Create a new branch from your chosen base branch diff --git a/README.md b/README.md index 48d1c1742f..e7a6e955b9 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ [protocol-badge]: https://img.shields.io/badge/protocol-modelcontextprotocol.io-blue.svg [protocol-url]: https://modelcontextprotocol.io [spec-badge]: https://img.shields.io/badge/spec-spec.modelcontextprotocol.io-blue.svg -[spec-url]: https://spec.modelcontextprotocol.io +[spec-url]: https://modelcontextprotocol.io/specification/latest ## Overview @@ -132,14 +132,14 @@ Let's create a simple MCP server that exposes a calculator tool and some data: """ FastMCP quickstart example. -cd to the `examples/snippets/clients` directory and run: - uv run server fastmcp_quickstart stdio +Run from the repository root: + uv run examples/snippets/servers/fastmcp_quickstart.py """ from mcp.server.fastmcp import FastMCP # Create an MCP server -mcp = FastMCP("Demo") +mcp = FastMCP("Demo", json_response=True) # Add an addition tool @@ -167,23 +167,36 @@ def greet_user(name: str, style: str = "friendly") -> str: } return f"{styles.get(style, styles['friendly'])} for someone named {name}." + + +# Run with streamable HTTP transport +if __name__ == "__main__": + mcp.run(transport="streamable-http") ``` _Full example: [examples/snippets/servers/fastmcp_quickstart.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/fastmcp_quickstart.py)_ -You can install this server in [Claude Desktop](https://claude.ai/download) and interact with it right away by running: +You can install this server in [Claude Code](https://docs.claude.com/en/docs/claude-code/mcp) and interact with it right away. First, run the server: ```bash -uv run mcp install server.py +uv run --with mcp examples/snippets/servers/fastmcp_quickstart.py ``` -Alternatively, you can test it with the MCP Inspector: +Then add it to Claude Code: ```bash -uv run mcp dev server.py +claude mcp add --transport http my-server http://localhost:8000/mcp +``` + +Alternatively, you can test it with the MCP Inspector. Start the server as above, then in a separate terminal: + +```bash +npx -y @modelcontextprotocol/inspector ``` +In the inspector UI, connect to `http://localhost:8000/mcp`. + ## What is MCP? The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. MCP servers can: @@ -383,6 +396,61 @@ causes the tool to be classified as structured _and this is undesirable_, the classification can be suppressed by passing `structured_output=False` to the `@tool` decorator. +##### Advanced: Direct CallToolResult + +For full control over tool responses including the `_meta` field (for passing data to client applications without exposing it to the model), you can return `CallToolResult` directly: + + +```python +"""Example showing direct CallToolResult return for advanced control.""" + +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server.fastmcp import FastMCP +from mcp.types import CallToolResult, TextContent + +mcp = FastMCP("CallToolResult Example") + + +class ValidationModel(BaseModel): + """Model for validating structured output.""" + + status: str + data: dict[str, int] + + +@mcp.tool() +def advanced_tool() -> CallToolResult: + """Return CallToolResult directly for full control including _meta field.""" + return CallToolResult( + content=[TextContent(type="text", text="Response visible to the model")], + _meta={"hidden": "data for client applications only"}, + ) + + +@mcp.tool() +def validated_tool() -> Annotated[CallToolResult, ValidationModel]: + """Return CallToolResult with structured output validation.""" + return CallToolResult( + content=[TextContent(type="text", text="Validated response")], + structuredContent={"status": "success", "data": {"result": 42}}, + _meta={"internal": "metadata"}, + ) + + +@mcp.tool() +def empty_result_tool() -> CallToolResult: + """For empty results, return CallToolResult with empty content.""" + return CallToolResult(content=[]) +``` + +_Full example: [examples/snippets/servers/direct_call_tool_result.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/direct_call_tool_result.py)_ + + +**Important:** `CallToolResult` must always be returned (no `Optional` or `Union`). For empty results, use `CallToolResult(content=[])`. For optional simple types, use `str | None` without `CallToolResult`. + ```python """Example showing structured output with tools.""" @@ -740,10 +808,21 @@ Request additional information from users. This example shows an Elicitation dur ```python +"""Elicitation examples demonstrating form and URL mode elicitation. + +Form mode elicitation collects structured, non-sensitive data through a schema. +URL mode elicitation directs users to external URLs for sensitive operations +like OAuth flows, credential collection, or payment processing. +""" + +import uuid + from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ElicitRequestURLParams mcp = FastMCP(name="Elicitation Example") @@ -760,7 +839,10 @@ class BookingPreferences(BaseModel): @mcp.tool() async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: - """Book a table with date availability check.""" + """Book a table with date availability check. + + This demonstrates form mode elicitation for collecting non-sensitive user input. + """ # Check if date is available if date == "2024-12-25": # Date unavailable - ask user for alternative @@ -777,6 +859,54 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS # Date available return f"[SUCCESS] Booked for {date} at {time}" + + +@mcp.tool() +async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: + """Process a secure payment requiring URL confirmation. + + This demonstrates URL mode elicitation using ctx.elicit_url() for + operations that require out-of-band user interaction. + """ + elicitation_id = str(uuid.uuid4()) + + result = await ctx.elicit_url( + message=f"Please confirm payment of ${amount:.2f}", + url=f"https://payments.example.com/confirm?amount={amount}&id={elicitation_id}", + elicitation_id=elicitation_id, + ) + + if result.action == "accept": + # In a real app, the payment confirmation would happen out-of-band + # and you'd verify the payment status from your backend + return f"Payment of ${amount:.2f} initiated - check your browser to complete" + elif result.action == "decline": + return "Payment declined by user" + return "Payment cancelled" + + +@mcp.tool() +async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + """Connect to a third-party service requiring OAuth authorization. + + This demonstrates the "throw error" pattern using UrlElicitationRequiredError. + Use this pattern when the tool cannot proceed without user authorization. + """ + elicitation_id = str(uuid.uuid4()) + + # Raise UrlElicitationRequiredError to signal that the client must complete + # a URL elicitation before this request can be processed. + # The MCP framework will convert this to a -32042 error response. + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message=f"Authorization required to connect to {service_name}", + url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}", + elicitationId=elicitation_id, + ) + ] + ) ``` _Full example: [examples/snippets/servers/elicitation.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/elicitation.py)_ @@ -818,6 +948,7 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: max_tokens=100, ) + # Since we're not passing tools param, result.content is single content if result.content.type == "text": return result.content.text return str(result.content) @@ -888,6 +1019,7 @@ class SimpleTokenVerifier(TokenVerifier): # Create FastMCP instance as a Resource Server mcp = FastMCP( "Weather Service", + json_response=True, # Token verifier for authentication token_verifier=SimpleTokenVerifier(), # Auth settings for RFC 9728 Protected Resource Metadata @@ -1103,7 +1235,7 @@ Note that `uv run mcp run` or `uv run mcp dev` only supports server using FastMC ### Streamable HTTP Transport -> **Note**: Streamable HTTP transport is superseding SSE transport for production deployments. +> **Note**: Streamable HTTP transport is the recommended transport for production deployments. Use `stateless_http=True` and `json_response=True` for optimal scalability. ```python @@ -1114,15 +1246,15 @@ Run from the repository root: from mcp.server.fastmcp import FastMCP -# Stateful server (maintains session state) -mcp = FastMCP("StatefulServer") +# Stateless server with JSON responses (recommended) +mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) # Other configuration options: -# Stateless server (no session persistence) +# Stateless server with SSE streaming responses # mcp = FastMCP("StatelessServer", stateless_http=True) -# Stateless server (no session persistence, no sse stream with supported client) -# mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) +# Stateful server with session persistence +# mcp = FastMCP("StatefulServer") # Add a simple tool to demonstrate the server @@ -1157,7 +1289,7 @@ from starlette.routing import Mount from mcp.server.fastmcp import FastMCP # Create the Echo server -echo_mcp = FastMCP(name="EchoServer", stateless_http=True) +echo_mcp = FastMCP(name="EchoServer", stateless_http=True, json_response=True) @echo_mcp.tool() @@ -1167,7 +1299,7 @@ def echo(message: str) -> str: # Create the Math server -math_mcp = FastMCP(name="MathServer", stateless_http=True) +math_mcp = FastMCP(name="MathServer", stateless_http=True, json_response=True) @math_mcp.tool() @@ -1262,13 +1394,15 @@ Run from the repository root: uvicorn examples.snippets.servers.streamable_http_basic_mounting:app --reload """ +import contextlib + from starlette.applications import Starlette from starlette.routing import Mount from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = FastMCP("My App") +mcp = FastMCP("My App", json_response=True) @mcp.tool() @@ -1277,11 +1411,19 @@ def hello() -> str: return "Hello from MCP!" +# Create a lifespan context manager to run the session manager +@contextlib.asynccontextmanager +async def lifespan(app: Starlette): + async with mcp.session_manager.run(): + yield + + # Mount the StreamableHTTP server to the existing ASGI server app = Starlette( routes=[ Mount("/", app=mcp.streamable_http_app()), - ] + ], + lifespan=lifespan, ) ``` @@ -1299,13 +1441,15 @@ Run from the repository root: uvicorn examples.snippets.servers.streamable_http_host_mounting:app --reload """ +import contextlib + from starlette.applications import Starlette from starlette.routing import Host from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = FastMCP("MCP Host App") +mcp = FastMCP("MCP Host App", json_response=True) @mcp.tool() @@ -1314,11 +1458,19 @@ def domain_info() -> str: return "This is served from mcp.acme.corp" +# Create a lifespan context manager to run the session manager +@contextlib.asynccontextmanager +async def lifespan(app: Starlette): + async with mcp.session_manager.run(): + yield + + # Mount using Host-based routing app = Starlette( routes=[ Host("mcp.acme.corp", app=mcp.streamable_http_app()), - ] + ], + lifespan=lifespan, ) ``` @@ -1336,14 +1488,16 @@ Run from the repository root: uvicorn examples.snippets.servers.streamable_http_multiple_servers:app --reload """ +import contextlib + from starlette.applications import Starlette from starlette.routing import Mount from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -api_mcp = FastMCP("API Server") -chat_mcp = FastMCP("Chat Server") +api_mcp = FastMCP("API Server", json_response=True) +chat_mcp = FastMCP("Chat Server", json_response=True) @api_mcp.tool() @@ -1363,12 +1517,23 @@ def send_message(message: str) -> str: api_mcp.settings.streamable_http_path = "/" chat_mcp.settings.streamable_http_path = "/" + +# Create a combined lifespan to manage both session managers +@contextlib.asynccontextmanager +async def lifespan(app: Starlette): + async with contextlib.AsyncExitStack() as stack: + await stack.enter_async_context(api_mcp.session_manager.run()) + await stack.enter_async_context(chat_mcp.session_manager.run()) + yield + + # Mount the servers app = Starlette( routes=[ Mount("/api", app=api_mcp.streamable_http_app()), Mount("/chat", app=chat_mcp.streamable_http_app()), - ] + ], + lifespan=lifespan, ) ``` @@ -1393,7 +1558,11 @@ from mcp.server.fastmcp import FastMCP # Configure streamable_http_path during initialization # This server will mount at the root of wherever it's mounted -mcp_at_root = FastMCP("My Server", streamable_http_path="/") +mcp_at_root = FastMCP( + "My Server", + json_response=True, + streamable_http_path="/", +) @mcp_at_root.tool() @@ -1769,14 +1938,93 @@ if __name__ == "__main__": _Full example: [examples/snippets/servers/lowlevel/structured_output.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/structured_output.py)_ -Tools can return data in three ways: +Tools can return data in four ways: 1. **Content only**: Return a list of content blocks (default behavior before spec revision 2025-06-18) 2. **Structured data only**: Return a dictionary that will be serialized to JSON (Introduced in spec revision 2025-06-18) 3. **Both**: Return a tuple of (content, structured_data) preferred option to use for backwards compatibility +4. **Direct CallToolResult**: Return `CallToolResult` directly for full control (including `_meta` field) When an `outputSchema` is defined, the server automatically validates the structured output against the schema. This ensures type safety and helps catch errors early. +##### Returning CallToolResult Directly + +For full control over the response including the `_meta` field (for passing data to client applications without exposing it to the model), return `CallToolResult` directly: + + +```python +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py +""" + +import asyncio +from typing import Any + +import mcp.server.stdio +import mcp.types as types +from mcp.server.lowlevel import NotificationOptions, Server +from mcp.server.models import InitializationOptions + +server = Server("example-server") + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + """List available tools.""" + return [ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + inputSchema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: + """Handle tool calls by returning CallToolResult directly.""" + if name == "advanced_tool": + message = str(arguments.get("message", "")) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Processed: {message}")], + structuredContent={"result": "success", "message": message}, + _meta={"hidden": "data for client applications only"}, + ) + + raise ValueError(f"Unknown tool: {name}") + + +async def run(): + """Run the server.""" + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="example", + server_version="0.1.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + +if __name__ == "__main__": + asyncio.run(run()) +``` + +_Full example: [examples/snippets/servers/lowlevel/direct_call_tool_result.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/lowlevel/direct_call_tool_result.py)_ + + +**Note:** When returning `CallToolResult`, you bypass the automatic content/structured conversion. You must construct the complete response yourself. + ### Pagination (Advanced) For servers that need to handle large datasets, the low-level server provides paginated versions of list operations. This is an optional optimization - most servers won't need pagination unless they're dealing with hundreds or thousands of items. @@ -1840,7 +2088,7 @@ import asyncio from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.types import Resource +from mcp.types import PaginatedRequestParams, Resource async def list_all_resources() -> None: @@ -1857,7 +2105,7 @@ async def list_all_resources() -> None: while True: # Fetch a page of resources - result = await session.list_resources(cursor=cursor) + result = await session.list_resources(params=PaginatedRequestParams(cursor=cursor)) all_resources.extend(result.resources) print(f"Fetched {len(result.resources)} resources") @@ -1993,12 +2241,12 @@ Run from the repository root: import asyncio from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client async def main(): # Connect to a streamable HTTP server - async with streamablehttp_client("http://localhost:8000/mcp") as ( + async with streamable_http_client("http://localhost:8000/mcp") as ( read_stream, write_stream, _, @@ -2122,11 +2370,12 @@ cd to the `examples/snippets` directory and run: import asyncio from urllib.parse import parse_qs, urlparse +import httpx from pydantic import AnyUrl from mcp import ClientSession from mcp.client.auth import OAuthClientProvider, TokenStorage -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -2180,15 +2429,16 @@ async def main(): callback_handler=handle_callback, ) - async with streamablehttp_client("http://localhost:8001/mcp", auth=oauth_auth) as (read, write, _): - async with ClientSession(read, write) as session: - await session.initialize() + async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client: + async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() - tools = await session.list_tools() - print(f"Available tools: {[tool.name for tool in tools.tools]}") + tools = await session.list_tools() + print(f"Available tools: {[tool.name for tool in tools.tools]}") - resources = await session.list_resources() - print(f"Available resources: {[r.uri for r in resources.resources]}") + resources = await session.list_resources() + print(f"Available resources: {[r.uri for r in resources.resources]}") def run(): @@ -2298,8 +2548,9 @@ MCP servers declare capabilities during initialization: ## Documentation - [API Reference](https://modelcontextprotocol.github.io/python-sdk/api/) +- [Experimental Features (Tasks)](https://modelcontextprotocol.github.io/python-sdk/experimental/tasks/) - [Model Context Protocol documentation](https://modelcontextprotocol.io) -- [Model Context Protocol specification](https://spec.modelcontextprotocol.io) +- [Model Context Protocol specification](https://modelcontextprotocol.io/specification/latest) - [Officially supported servers](https://github.com/modelcontextprotocol/servers) ## Contributing diff --git a/docs/experimental/index.md b/docs/experimental/index.md new file mode 100644 index 0000000000..1d496b3f10 --- /dev/null +++ b/docs/experimental/index.md @@ -0,0 +1,43 @@ +# Experimental Features + +!!! warning "Experimental APIs" + + The features in this section are experimental and may change without notice. + They track the evolving MCP specification and are not yet stable. + +This section documents experimental features in the MCP Python SDK. These features +implement draft specifications that are still being refined. + +## Available Experimental Features + +### [Tasks](tasks.md) + +Tasks enable asynchronous execution of MCP operations. Instead of waiting for a +long-running operation to complete, the server returns a task reference immediately. +Clients can then poll for status updates and retrieve results when ready. + +Tasks are useful for: + +- **Long-running computations** that would otherwise block +- **Batch operations** that process many items +- **Interactive workflows** that require user input (elicitation) or LLM assistance (sampling) + +## Using Experimental APIs + +Experimental features are accessed via the `.experimental` property: + +```python +# Server-side +@server.experimental.get_task() +async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + ... + +# Client-side +result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) +``` + +## Providing Feedback + +Since these features are experimental, feedback is especially valuable. If you encounter +issues or have suggestions, please open an issue on the +[python-sdk repository](https://github.com/modelcontextprotocol/python-sdk/issues). diff --git a/docs/experimental/tasks-client.md b/docs/experimental/tasks-client.md new file mode 100644 index 0000000000..cfd23e4e14 --- /dev/null +++ b/docs/experimental/tasks-client.md @@ -0,0 +1,361 @@ +# Client Task Usage + +!!! warning "Experimental" + + Tasks are an experimental feature. The API may change without notice. + +This guide covers calling task-augmented tools from clients, handling the `input_required` status, and advanced patterns like receiving task requests from servers. + +## Quick Start + +Call a tool as a task and poll for the result: + +```python +from mcp.client.session import ClientSession +from mcp.types import CallToolResult + +async with ClientSession(read, write) as session: + await session.initialize() + + # Call tool as task + result = await session.experimental.call_tool_as_task( + "process_data", + {"input": "hello"}, + ttl=60000, + ) + task_id = result.task.taskId + + # Poll until complete + async for status in session.experimental.poll_task(task_id): + print(f"Status: {status.status} - {status.statusMessage or ''}") + + # Get result + final = await session.experimental.get_task_result(task_id, CallToolResult) + print(f"Result: {final.content[0].text}") +``` + +## Calling Tools as Tasks + +Use `call_tool_as_task()` to invoke a tool with task augmentation: + +```python +result = await session.experimental.call_tool_as_task( + "my_tool", # Tool name + {"arg": "value"}, # Arguments + ttl=60000, # Time-to-live in milliseconds + meta={"key": "val"}, # Optional metadata +) + +task_id = result.task.taskId +print(f"Task: {task_id}, Status: {result.task.status}") +``` + +The response is a `CreateTaskResult` containing: + +- `task.taskId` - Unique identifier for polling +- `task.status` - Initial status (usually `"working"`) +- `task.pollInterval` - Suggested polling interval (milliseconds) +- `task.ttl` - Time-to-live for results +- `task.createdAt` - Creation timestamp + +## Polling with poll_task + +The `poll_task()` async iterator polls until the task reaches a terminal state: + +```python +async for status in session.experimental.poll_task(task_id): + print(f"Status: {status.status}") + if status.statusMessage: + print(f"Progress: {status.statusMessage}") +``` + +It automatically: + +- Respects the server's suggested `pollInterval` +- Stops when status is `completed`, `failed`, or `cancelled` +- Yields each status for progress display + +### Handling input_required + +When a task needs user input (elicitation), it transitions to `input_required`. You must call `get_task_result()` to receive and respond to the elicitation: + +```python +async for status in session.experimental.poll_task(task_id): + print(f"Status: {status.status}") + + if status.status == "input_required": + # This delivers the elicitation and waits for completion + final = await session.experimental.get_task_result(task_id, CallToolResult) + break +``` + +The elicitation callback (set during session creation) handles the actual user interaction. + +## Elicitation Callbacks + +To handle elicitation requests from the server, provide a callback when creating the session: + +```python +from mcp.types import ElicitRequestParams, ElicitResult + +async def handle_elicitation(context, params: ElicitRequestParams) -> ElicitResult: + # Display the message to the user + print(f"Server asks: {params.message}") + + # Collect user input (this is a simplified example) + response = input("Your response (y/n): ") + confirmed = response.lower() == "y" + + return ElicitResult( + action="http://github.com/accept", + content={"confirm": confirmed}, + ) + +async with ClientSession( + read, + write, + elicitation_callback=handle_elicitation, +) as session: + await session.initialize() + # ... call tasks that may require elicitation +``` + +## Sampling Callbacks + +Similarly, handle sampling requests with a callback: + +```python +from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent + +async def handle_sampling(context, params: CreateMessageRequestParams) -> CreateMessageResult: + # In a real implementation, call your LLM here + prompt = params.messages[-1].content.text if params.messages else "" + + # Return a mock response + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text=f"Response to: {prompt}"), + model="my-model", + ) + +async with ClientSession( + read, + write, + sampling_callback=handle_sampling, +) as session: + # ... +``` + +## Retrieving Results + +Once a task completes, retrieve the result: + +```python +if status.status == "completed": + result = await session.experimental.get_task_result(task_id, CallToolResult) + for content in result.content: + if hasattr(content, "text"): + print(content.text) + +elif status.status == "failed": + print(f"Task failed: {status.statusMessage}") + +elif status.status == "cancelled": + print("Task was cancelled") +``` + +The result type matches the original request: + +- `tools/call` → `CallToolResult` +- `sampling/createMessage` → `CreateMessageResult` +- `elicitation/create` → `ElicitResult` + +## Cancellation + +Cancel a running task: + +```python +cancel_result = await session.experimental.cancel_task(task_id) +print(f"Cancelled, status: {cancel_result.status}") +``` + +Note: Cancellation is cooperative—the server must check for and handle cancellation. + +## Listing Tasks + +View all tasks on the server: + +```python +result = await session.experimental.list_tasks() +for task in result.tasks: + print(f"{task.taskId}: {task.status}") + +# Handle pagination +while result.nextCursor: + result = await session.experimental.list_tasks(cursor=result.nextCursor) + for task in result.tasks: + print(f"{task.taskId}: {task.status}") +``` + +## Advanced: Client as Task Receiver + +Servers can send task-augmented requests to clients. This is useful when the server needs the client to perform async work (like complex sampling or user interaction). + +### Declaring Client Capabilities + +Register task handlers to declare what task-augmented requests your client accepts: + +```python +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers +from mcp.types import ( + CreateTaskResult, GetTaskResult, GetTaskPayloadResult, + TaskMetadata, ElicitRequestParams, +) +from mcp.shared.experimental.tasks import InMemoryTaskStore + +# Client-side task store +client_store = InMemoryTaskStore() + +async def handle_augmented_elicitation(context, params: ElicitRequestParams, task_metadata: TaskMetadata): + """Handle task-augmented elicitation from server.""" + # Create a task for this elicitation + task = await client_store.create_task(task_metadata) + + # Start async work (e.g., show UI, wait for user) + async def complete_elicitation(): + # ... do async work ... + result = ElicitResult(action="http://github.com/accept", content={"confirm": True}) + await client_store.store_result(task.taskId, result) + await client_store.update_task(task.taskId, status="completed") + + context.session._task_group.start_soon(complete_elicitation) + + # Return task reference immediately + return CreateTaskResult(task=task) + +async def handle_get_task(context, params): + """Handle tasks/get from server.""" + task = await client_store.get_task(params.taskId) + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=100, + ) + +async def handle_get_task_result(context, params): + """Handle tasks/result from server.""" + result = await client_store.get_result(params.taskId) + return GetTaskPayloadResult.model_validate(result.model_dump()) + +task_handlers = ExperimentalTaskHandlers( + augmented_elicitation=handle_augmented_elicitation, + get_task=handle_get_task, + get_task_result=handle_get_task_result, +) + +async with ClientSession( + read, + write, + experimental_task_handlers=task_handlers, +) as session: + # Client now accepts task-augmented elicitation from server + await session.initialize() +``` + +This enables flows where: + +1. Client calls a task-augmented tool +2. Server's tool work calls `task.elicit_as_task()` +3. Client receives task-augmented elicitation +4. Client creates its own task, does async work +5. Server polls client's task +6. Eventually both tasks complete + +## Complete Example + +A client that handles all task scenarios: + +```python +import anyio +from mcp.client.session import ClientSession +from mcp.client.stdio import stdio_client +from mcp.types import CallToolResult, ElicitRequestParams, ElicitResult + + +async def elicitation_callback(context, params: ElicitRequestParams) -> ElicitResult: + print(f"\n[Elicitation] {params.message}") + response = input("Confirm? (y/n): ") + return ElicitResult(action="http://github.com/accept", content={"confirm": response.lower() == "y"}) + + +async def main(): + async with stdio_client(command="python", args=["server.py"]) as (read, write): + async with ClientSession( + read, + write, + elicitation_callback=elicitation_callback, + ) as session: + await session.initialize() + + # List available tools + tools = await session.list_tools() + print("Tools:", [t.name for t in tools.tools]) + + # Call a task-augmented tool + print("\nCalling task tool...") + result = await session.experimental.call_tool_as_task( + "confirm_action", + {"action": "delete files"}, + ) + task_id = result.task.taskId + print(f"Task created: {task_id}") + + # Poll and handle input_required + async for status in session.experimental.poll_task(task_id): + print(f"Status: {status.status}") + + if status.status == "input_required": + final = await session.experimental.get_task_result(task_id, CallToolResult) + print(f"Result: {final.content[0].text}") + break + + if status.status == "completed": + final = await session.experimental.get_task_result(task_id, CallToolResult) + print(f"Result: {final.content[0].text}") + + +if __name__ == "__main__": + anyio.run(main) +``` + +## Error Handling + +Handle task errors gracefully: + +```python +from mcp.shared.exceptions import McpError + +try: + result = await session.experimental.call_tool_as_task("my_tool", args) + task_id = result.task.taskId + + async for status in session.experimental.poll_task(task_id): + if status.status == "failed": + raise RuntimeError(f"Task failed: {status.statusMessage}") + + final = await session.experimental.get_task_result(task_id, CallToolResult) + +except McpError as e: + print(f"MCP error: {e.error.message}") +except Exception as e: + print(f"Error: {e}") +``` + +## Next Steps + +- [Server Implementation](tasks-server.md) - Build task-supporting servers +- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md new file mode 100644 index 0000000000..761dc5de5c --- /dev/null +++ b/docs/experimental/tasks-server.md @@ -0,0 +1,597 @@ +# Server Task Implementation + +!!! warning "Experimental" + + Tasks are an experimental feature. The API may change without notice. + +This guide covers implementing task support in MCP servers, from basic setup to advanced patterns like elicitation and sampling within tasks. + +## Quick Start + +The simplest way to add task support: + +```python +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED + +server = Server("my-server") +server.experimental.enable_tasks() # Registers all task handlers automatically + +@server.list_tools() +async def list_tools(): + return [ + Tool( + name="process_data", + description="Process data asynchronously", + inputSchema={"type": "object", "properties": {"input": {"type": "string"}}}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + +@server.call_tool() +async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: + if name == "process_data": + return await handle_process_data(arguments) + return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) + +async def handle_process_data(arguments: dict) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Processing...") + result = arguments.get("input", "").upper() + return CallToolResult(content=[TextContent(type="text", text=result)]) + + return await ctx.experimental.run_task(work) +``` + +That's it. `enable_tasks()` automatically: + +- Creates an in-memory task store +- Registers handlers for `tasks/get`, `tasks/result`, `tasks/list`, `tasks/cancel` +- Updates server capabilities + +## Tool Declaration + +Tools declare task support via the `execution.taskSupport` field: + +```python +from mcp.types import Tool, ToolExecution, TASK_REQUIRED, TASK_OPTIONAL, TASK_FORBIDDEN + +Tool( + name="my_tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), # or TASK_OPTIONAL, TASK_FORBIDDEN +) +``` + +| Value | Meaning | +|-------|---------| +| `TASK_REQUIRED` | Tool **must** be called as a task | +| `TASK_OPTIONAL` | Tool supports both sync and task execution | +| `TASK_FORBIDDEN` | Tool **cannot** be called as a task (default) | + +Validate the request matches your tool's requirements: + +```python +@server.call_tool() +async def handle_tool(name: str, arguments: dict): + ctx = server.request_context + + if name == "required_task_tool": + ctx.experimental.validate_task_mode(TASK_REQUIRED) # Raises if not task mode + return await handle_as_task(arguments) + + elif name == "optional_task_tool": + if ctx.experimental.is_task: + return await handle_as_task(arguments) + else: + return handle_sync(arguments) +``` + +## The run_task Pattern + +`run_task()` is the recommended way to execute task work: + +```python +async def handle_my_tool(arguments: dict) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + # Your work here + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + return await ctx.experimental.run_task(work) +``` + +**What `run_task()` does:** + +1. Creates a task in the store +2. Spawns your work function in the background +3. Returns `CreateTaskResult` immediately +4. Auto-completes the task when your function returns +5. Auto-fails the task if your function raises + +**The `ServerTaskContext` provides:** + +- `task.task_id` - The task identifier +- `task.update_status(message)` - Update progress +- `task.complete(result)` - Explicitly complete (usually automatic) +- `task.fail(error)` - Explicitly fail +- `task.is_cancelled` - Check if cancellation requested + +## Status Updates + +Keep clients informed of progress: + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Starting...") + + for i, item in enumerate(items): + await task.update_status(f"Processing {i+1}/{len(items)}") + await process_item(item) + + await task.update_status("Finalizing...") + return CallToolResult(content=[TextContent(type="text", text="Complete")]) +``` + +Status messages appear in `tasks/get` responses, letting clients show progress to users. + +## Elicitation Within Tasks + +Tasks can request user input via elicitation. This transitions the task to `input_required` status. + +### Form Elicitation + +Collect structured data from the user: + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Waiting for confirmation...") + + result = await task.elicit( + message="Delete these files?", + requestedSchema={ + "type": "object", + "properties": { + "confirm": {"type": "boolean"}, + "reason": {"type": "string"}, + }, + "required": ["confirm"], + }, + ) + + if result.action == "accept" and result.content.get("confirm"): + # User confirmed + return CallToolResult(content=[TextContent(type="text", text="Files deleted")]) + else: + # User declined or cancelled + return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) +``` + +### URL Elicitation + +Direct users to external URLs for OAuth, payments, or other out-of-band flows: + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Waiting for OAuth...") + + result = await task.elicit_url( + message="Please authorize with GitHub", + url="https://github.com/login/oauth/authorize?client_id=...", + elicitation_id="oauth-github-123", + ) + + if result.action == "accept": + # User completed OAuth flow + return CallToolResult(content=[TextContent(type="text", text="Connected to GitHub")]) + else: + return CallToolResult(content=[TextContent(type="text", text="OAuth cancelled")]) +``` + +## Sampling Within Tasks + +Tasks can request LLM completions from the client: + +```python +from mcp.types import SamplingMessage, TextContent + +async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Generating response...") + + result = await task.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent(type="text", text="Write a haiku about coding"), + ) + ], + max_tokens=100, + ) + + haiku = result.content.text if isinstance(result.content, TextContent) else "Error" + return CallToolResult(content=[TextContent(type="text", text=haiku)]) +``` + +Sampling supports additional parameters: + +```python +result = await task.create_message( + messages=[...], + max_tokens=500, + system_prompt="You are a helpful assistant", + temperature=0.7, + stop_sequences=["\n\n"], + model_preferences=ModelPreferences(hints=[ModelHint(name="claude-3")]), +) +``` + +## Cancellation Support + +Check for cancellation in long-running work: + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + for i in range(1000): + if task.is_cancelled: + # Clean up and exit + return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) + + await task.update_status(f"Step {i}/1000") + await process_step(i) + + return CallToolResult(content=[TextContent(type="text", text="Complete")]) +``` + +The SDK's default cancel handler updates the task status. Your work function should check `is_cancelled` periodically. + +## Custom Task Store + +For production, implement `TaskStore` with persistent storage: + +```python +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import Task, TaskMetadata, Result + +class RedisTaskStore(TaskStore): + def __init__(self, redis_client): + self.redis = redis_client + + async def create_task(self, metadata: TaskMetadata, task_id: str | None = None) -> Task: + # Create and persist task + ... + + async def get_task(self, task_id: str) -> Task | None: + # Retrieve task from Redis + ... + + async def update_task(self, task_id: str, status: str | None = None, ...) -> Task: + # Update and persist + ... + + async def store_result(self, task_id: str, result: Result) -> None: + # Store result in Redis + ... + + async def get_result(self, task_id: str) -> Result | None: + # Retrieve result + ... + + # ... implement remaining methods +``` + +Use your custom store: + +```python +store = RedisTaskStore(redis_client) +server.experimental.enable_tasks(store=store) +``` + +## Complete Example + +A server with multiple task-supporting tools: + +```python +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.types import ( + CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, + SamplingMessage, TASK_REQUIRED, +) + +server = Server("task-demo") +server.experimental.enable_tasks() + + +@server.list_tools() +async def list_tools(): + return [ + Tool( + name="confirm_action", + description="Requires user confirmation", + inputSchema={"type": "object", "properties": {"action": {"type": "string"}}}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + Tool( + name="generate_text", + description="Generate text via LLM", + inputSchema={"type": "object", "properties": {"prompt": {"type": "string"}}}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + ] + + +async def handle_confirm_action(arguments: dict) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + action = arguments.get("action", "unknown action") + + async def work(task: ServerTaskContext) -> CallToolResult: + result = await task.elicit( + message=f"Confirm: {action}?", + requestedSchema={ + "type": "object", + "properties": {"confirm": {"type": "boolean"}}, + "required": ["confirm"], + }, + ) + + if result.action == "accept" and result.content.get("confirm"): + return CallToolResult(content=[TextContent(type="text", text=f"Executed: {action}")]) + return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) + + return await ctx.experimental.run_task(work) + + +async def handle_generate_text(arguments: dict) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + prompt = arguments.get("prompt", "Hello") + + async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Generating...") + + result = await task.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], + max_tokens=200, + ) + + text = result.content.text if isinstance(result.content, TextContent) else "Error" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + return await ctx.experimental.run_task(work) + + +@server.call_tool() +async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: + if name == "confirm_action": + return await handle_confirm_action(arguments) + elif name == "generate_text": + return await handle_generate_text(arguments) + return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) +``` + +## Error Handling in Tasks + +Tasks handle errors automatically, but you can also fail explicitly: + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + try: + result = await risky_operation() + return CallToolResult(content=[TextContent(type="text", text=result)]) + except PermissionError: + await task.fail("Access denied - insufficient permissions") + raise + except TimeoutError: + await task.fail("Operation timed out after 30 seconds") + raise +``` + +When `run_task()` catches an exception, it automatically: + +1. Marks the task as `failed` +2. Sets `statusMessage` to the exception message +3. Propagates the exception (which is caught by the task group) + +For custom error messages, call `task.fail()` before raising. + +## HTTP Transport Example + +For web applications, use the Streamable HTTP transport: + +```python +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import uvicorn +from starlette.applications import Starlette +from starlette.routing import Mount + +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.types import ( + CallToolResult, CreateTaskResult, TextContent, Tool, ToolExecution, TASK_REQUIRED, +) + + +server = Server("http-task-server") +server.experimental.enable_tasks() + + +@server.list_tools() +async def list_tools(): + return [ + Tool( + name="long_operation", + description="A long-running operation", + inputSchema={"type": "object", "properties": {"duration": {"type": "number"}}}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + +async def handle_long_operation(arguments: dict) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + duration = arguments.get("duration", 5) + + async def work(task: ServerTaskContext) -> CallToolResult: + import anyio + for i in range(int(duration)): + await task.update_status(f"Step {i+1}/{int(duration)}") + await anyio.sleep(1) + return CallToolResult(content=[TextContent(type="text", text=f"Completed after {duration}s")]) + + return await ctx.experimental.run_task(work) + + +@server.call_tool() +async def handle_tool(name: str, arguments: dict) -> CallToolResult | CreateTaskResult: + if name == "long_operation": + return await handle_long_operation(arguments) + return CallToolResult(content=[TextContent(type="text", text=f"Unknown: {name}")], isError=True) + + +def create_app(): + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + lifespan=lifespan, + ) + + +if __name__ == "__main__": + uvicorn.run(create_app(), host="127.0.0.1", port=8000) +``` + +## Testing Task Servers + +Test task functionality with the SDK's testing utilities: + +```python +import pytest +import anyio +from mcp.client.session import ClientSession +from mcp.types import CallToolResult + + +@pytest.mark.anyio +async def test_task_tool(): + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream(10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream(10) + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def run_client(): + async with ClientSession(server_to_client_receive, client_to_server_send) as session: + await session.initialize() + + # Call the tool as a task + result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) + task_id = result.task.taskId + assert result.task.status == "working" + + # Poll until complete + async for status in session.experimental.poll_task(task_id): + if status.status in ("completed", "failed"): + break + + # Get result + final = await session.experimental.get_task_result(task_id, CallToolResult) + assert len(final.content) > 0 + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) +``` + +## Best Practices + +### Keep Work Functions Focused + +```python +# Good: focused work function +async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Validating...") + validate_input(arguments) + + await task.update_status("Processing...") + result = await process_data(arguments) + + return CallToolResult(content=[TextContent(type="text", text=result)]) +``` + +### Check Cancellation in Loops + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + results = [] + for item in large_dataset: + if task.is_cancelled: + return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) + + results.append(await process(item)) + + return CallToolResult(content=[TextContent(type="text", text=str(results))]) +``` + +### Use Meaningful Status Messages + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Connecting to database...") + db = await connect() + + await task.update_status("Fetching records (0/1000)...") + for i, record in enumerate(records): + if i % 100 == 0: + await task.update_status(f"Processing records ({i}/1000)...") + await process(record) + + await task.update_status("Finalizing results...") + return CallToolResult(content=[TextContent(type="text", text="Done")]) +``` + +### Handle Elicitation Responses + +```python +async def work(task: ServerTaskContext) -> CallToolResult: + result = await task.elicit(message="Continue?", requestedSchema={...}) + + match result.action: + case "accept": + # User accepted, process content + return await process_accepted(result.content) + case "decline": + # User explicitly declined + return CallToolResult(content=[TextContent(type="text", text="User declined")]) + case "cancel": + # User cancelled the elicitation + return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) +``` + +## Next Steps + +- [Client Usage](tasks-client.md) - Learn how clients interact with task servers +- [Tasks Overview](tasks.md) - Review lifecycle and concepts diff --git a/docs/experimental/tasks.md b/docs/experimental/tasks.md new file mode 100644 index 0000000000..2d4d06a025 --- /dev/null +++ b/docs/experimental/tasks.md @@ -0,0 +1,188 @@ +# Tasks + +!!! warning "Experimental" + + Tasks are an experimental feature tracking the draft MCP specification. + The API may change without notice. + +Tasks enable asynchronous request handling in MCP. Instead of blocking until an operation completes, the receiver creates a task, returns immediately, and the requestor polls for the result. + +## When to Use Tasks + +Tasks are designed for operations that: + +- Take significant time (seconds to minutes) +- Need progress updates during execution +- Require user input mid-execution (elicitation, sampling) +- Should run without blocking the requestor + +Common use cases: + +- Long-running data processing +- Multi-step workflows with user confirmation +- LLM-powered operations requiring sampling +- OAuth flows requiring user browser interaction + +## Task Lifecycle + +```text + ┌─────────────┐ + │ working │ + └──────┬──────┘ + │ + ┌────────────┼────────────┐ + │ │ │ + ▼ ▼ ▼ + ┌────────────┐ ┌───────────┐ ┌───────────┐ + │ completed │ │ failed │ │ cancelled │ + └────────────┘ └───────────┘ └───────────┘ + ▲ + │ + ┌────────┴────────┐ + │ input_required │◄──────┐ + └────────┬────────┘ │ + │ │ + └────────────────┘ +``` + +| Status | Description | +|--------|-------------| +| `working` | Task is being processed | +| `input_required` | Receiver needs input from requestor (elicitation/sampling) | +| `completed` | Task finished successfully | +| `failed` | Task encountered an error | +| `cancelled` | Task was cancelled by requestor | + +Terminal states (`completed`, `failed`, `cancelled`) are final—tasks cannot transition out of them. + +## Bidirectional Flow + +Tasks work in both directions: + +**Client → Server** (most common): + +```text +Client Server + │ │ + │── tools/call (task) ──────────────>│ Creates task + │<── CreateTaskResult ───────────────│ + │ │ + │── tasks/get ──────────────────────>│ + │<── status: working ────────────────│ + │ │ ... work continues ... + │── tasks/get ──────────────────────>│ + │<── status: completed ──────────────│ + │ │ + │── tasks/result ───────────────────>│ + │<── CallToolResult ─────────────────│ +``` + +**Server → Client** (for elicitation/sampling): + +```text +Server Client + │ │ + │── elicitation/create (task) ──────>│ Creates task + │<── CreateTaskResult ───────────────│ + │ │ + │── tasks/get ──────────────────────>│ + │<── status: working ────────────────│ + │ │ ... user interaction ... + │── tasks/get ──────────────────────>│ + │<── status: completed ──────────────│ + │ │ + │── tasks/result ───────────────────>│ + │<── ElicitResult ───────────────────│ +``` + +## Key Concepts + +### Task Metadata + +When augmenting a request with task execution, include `TaskMetadata`: + +```python +from mcp.types import TaskMetadata + +task = TaskMetadata(ttl=60000) # TTL in milliseconds +``` + +The `ttl` (time-to-live) specifies how long the task and result are retained after completion. + +### Task Store + +Servers persist task state in a `TaskStore`. The SDK provides `InMemoryTaskStore` for development: + +```python +from mcp.shared.experimental.tasks import InMemoryTaskStore + +store = InMemoryTaskStore() +``` + +For production, implement `TaskStore` with a database or distributed cache. + +### Capabilities + +Both servers and clients declare task support through capabilities: + +**Server capabilities:** + +- `tasks.requests.tools.call` - Server accepts task-augmented tool calls + +**Client capabilities:** + +- `tasks.requests.sampling.createMessage` - Client accepts task-augmented sampling +- `tasks.requests.elicitation.create` - Client accepts task-augmented elicitation + +The SDK manages these automatically when you enable task support. + +## Quick Example + +**Server** (simplified API): + +```python +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.types import CallToolResult, TextContent, TASK_REQUIRED + +server = Server("my-server") +server.experimental.enable_tasks() # One-line setup + +@server.call_tool() +async def handle_tool(name: str, arguments: dict): + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext): + await task.update_status("Processing...") + # ... do work ... + return CallToolResult(content=[TextContent(type="text", text="Done!")]) + + return await ctx.experimental.run_task(work) +``` + +**Client:** + +```python +from mcp.client.session import ClientSession +from mcp.types import CallToolResult + +async with ClientSession(read, write) as session: + await session.initialize() + + # Call tool as task + result = await session.experimental.call_tool_as_task("my_tool", {"arg": "value"}) + task_id = result.task.taskId + + # Poll until done + async for status in session.experimental.poll_task(task_id): + print(f"Status: {status.status}") + + # Get result + final = await session.experimental.get_task_result(task_id, CallToolResult) +``` + +## Next Steps + +- [Server Implementation](tasks-server.md) - Build task-supporting servers +- [Client Usage](tasks-client.md) - Call and poll tasks from clients diff --git a/docs/index.md b/docs/index.md index 139afca4aa..eb5ddf4000 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,7 +17,7 @@ Here's a simple MCP server that exposes a tool, resource, and prompt: ```python title="server.py" from mcp.server.fastmcp import FastMCP -mcp = FastMCP("Test Server") +mcp = FastMCP("Test Server", json_response=True) @mcp.tool() @@ -36,12 +36,22 @@ def get_greeting(name: str) -> str: def greet_user(name: str, style: str = "friendly") -> str: """Generate a greeting prompt""" return f"Write a {style} greeting for someone named {name}." + + +if __name__ == "__main__": + mcp.run(transport="streamable-http") +``` + +Run the server: + +```bash +uv run --with mcp server.py ``` -Test it with the [MCP Inspector](https://github.com/modelcontextprotocol/inspector): +Then open the [MCP Inspector](https://github.com/modelcontextprotocol/inspector) and connect to `http://localhost:8000/mcp`: ```bash -uv run mcp dev server.py +npx -y @modelcontextprotocol/inspector ``` ## Getting Started diff --git a/examples/clients/conformance-auth-client/README.md b/examples/clients/conformance-auth-client/README.md new file mode 100644 index 0000000000..312a992d0a --- /dev/null +++ b/examples/clients/conformance-auth-client/README.md @@ -0,0 +1,49 @@ +# MCP Conformance Auth Client + +A Python OAuth client designed for use with the MCP conformance test framework. + +## Overview + +This client implements OAuth authentication for MCP and is designed to work automatically with the conformance test framework without requiring user interaction. It programmatically fetches authorization URLs and extracts auth codes from redirects. + +## Installation + +```bash +cd examples/clients/conformance-auth-client +uv sync +``` + +## Usage with Conformance Tests + +Run the auth conformance tests against this Python client: + +```bash +# From the conformance repository +npx @modelcontextprotocol/conformance client \ + --command "uv run --directory /path/to/python-sdk/examples/clients/conformance-auth-client python -m mcp_conformance_auth_client" \ + --scenario auth/basic-dcr +``` + +Available auth test scenarios: + +- `auth/basic-dcr` - Tests OAuth Dynamic Client Registration flow +- `auth/basic-metadata-var1` - Tests OAuth with authorization metadata + +## How It Works + +Unlike interactive OAuth clients that open a browser for user authentication, this client: + +1. Receives the authorization URL from the OAuth provider +2. Makes an HTTP request to that URL directly (without following redirects) +3. Extracts the authorization code from the redirect response +4. Uses the code to complete the OAuth token exchange + +This allows the conformance test framework's mock OAuth server to automatically provide auth codes without human interaction. + +## Direct Usage + +You can also run the client directly: + +```bash +uv run python -m mcp_conformance_auth_client http://localhost:3000/mcp +``` diff --git a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py new file mode 100644 index 0000000000..ba8679e3ac --- /dev/null +++ b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 +""" +MCP OAuth conformance test client. + +This client is designed to work with the MCP conformance test framework. +It automatically handles OAuth flows without user interaction by programmatically +fetching the authorization URL and extracting the auth code from the redirect. + +Usage: + python -m mcp_conformance_auth_client + +Environment Variables: + MCP_CONFORMANCE_CONTEXT - JSON object containing test credentials: + { + "client_id": "...", + "client_secret": "...", # For client_secret_basic flow + "private_key_pem": "...", # For private_key_jwt flow + "signing_algorithm": "ES256" # Optional, defaults to ES256 + } + +Scenarios: + auth/* - Authorization code flow scenarios (default behavior) + auth/client-credentials-jwt - Client credentials with JWT authentication (SEP-1046) + auth/client-credentials-basic - Client credentials with client_secret_basic +""" + +import asyncio +import json +import logging +import os +import sys +from urllib.parse import ParseResult, parse_qs, urlparse + +import httpx +from mcp import ClientSession +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.auth.extensions.client_credentials import ( + ClientCredentialsOAuthProvider, + PrivateKeyJWTOAuthProvider, + SignedJWTParameters, +) +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from pydantic import AnyUrl + + +def get_conformance_context() -> dict: + """Load conformance test context from MCP_CONFORMANCE_CONTEXT environment variable.""" + context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT") + if not context_json: + raise RuntimeError( + "MCP_CONFORMANCE_CONTEXT environment variable not set. " + "Expected JSON with client_id, client_secret, and/or private_key_pem." + ) + try: + return json.loads(context_json) + except json.JSONDecodeError as e: + raise RuntimeError(f"Failed to parse MCP_CONFORMANCE_CONTEXT as JSON: {e}") from e + + +# Set up logging to stderr (stdout is for conformance test output) +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + stream=sys.stderr, +) +logger = logging.getLogger(__name__) + + +class InMemoryTokenStorage(TokenStorage): + """Simple in-memory token storage for conformance testing.""" + + def __init__(self): + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +class ConformanceOAuthCallbackHandler: + """ + OAuth callback handler that automatically fetches the authorization URL + and extracts the auth code, without requiring user interaction. + + This mimics the behavior of the TypeScript ConformanceOAuthProvider. + """ + + def __init__(self): + self._auth_code: str | None = None + self._state: str | None = None + + async def handle_redirect(self, authorization_url: str) -> None: + """ + Fetch the authorization URL and extract the auth code from the redirect. + + The conformance test server returns a redirect with the auth code, + so we can capture it programmatically. + """ + logger.debug(f"Fetching authorization URL: {authorization_url}") + + async with httpx.AsyncClient() as client: + response = await client.get( + authorization_url, + follow_redirects=False, # Don't follow redirects automatically + ) + + # Check for redirect response + if response.status_code in (301, 302, 303, 307, 308): + location = response.headers.get("location") + if location: + redirect_url: ParseResult = urlparse(location) + query_params: dict[str, list[str]] = parse_qs(redirect_url.query) + + if "code" in query_params: + self._auth_code = query_params["code"][0] + state_values = query_params.get("state") + self._state = state_values[0] if state_values else None + logger.debug(f"Got auth code from redirect: {self._auth_code[:10]}...") + return + else: + raise RuntimeError(f"No auth code in redirect URL: {location}") + else: + raise RuntimeError(f"No redirect location received from {authorization_url}") + else: + raise RuntimeError(f"Expected redirect response, got {response.status_code} from {authorization_url}") + + async def handle_callback(self) -> tuple[str, str | None]: + """Return the captured auth code and state, then clear them for potential reuse.""" + if self._auth_code is None: + raise RuntimeError("No authorization code available - was handle_redirect called?") + auth_code = self._auth_code + state = self._state + # Clear the stored values so the next auth flow gets fresh ones + self._auth_code = None + self._state = None + return auth_code, state + + +async def run_authorization_code_client(server_url: str) -> None: + """ + Run the conformance test client with authorization code flow. + + This function: + 1. Connects to the MCP server with OAuth authorization code flow + 2. Initializes the session + 3. Lists available tools + 4. Calls a test tool + """ + logger.debug(f"Starting conformance auth client (authorization_code) for {server_url}") + + # Create callback handler that will automatically fetch auth codes + callback_handler = ConformanceOAuthCallbackHandler() + + # Create OAuth authentication handler + oauth_auth = OAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="conformance-auth-client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + ), + storage=InMemoryTokenStorage(), + redirect_handler=callback_handler.handle_redirect, + callback_handler=callback_handler.handle_callback, + ) + + await _run_session(server_url, oauth_auth) + + +async def run_client_credentials_jwt_client(server_url: str) -> None: + """ + Run the conformance test client with client credentials flow using private_key_jwt (SEP-1046). + + This function: + 1. Connects to the MCP server with OAuth client_credentials grant + 2. Uses private_key_jwt authentication with credentials from MCP_CONFORMANCE_CONTEXT + 3. Initializes the session + 4. Lists available tools + 5. Calls a test tool + """ + logger.debug(f"Starting conformance auth client (client_credentials_jwt) for {server_url}") + + # Load credentials from environment + context = get_conformance_context() + client_id = context.get("client_id") + private_key_pem = context.get("private_key_pem") + signing_algorithm = context.get("signing_algorithm", "ES256") + + if not client_id: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_id'") + if not private_key_pem: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'private_key_pem'") + + # Create JWT parameters for SDK-signed assertions + jwt_params = SignedJWTParameters( + issuer=client_id, + subject=client_id, + signing_algorithm=signing_algorithm, + signing_key=private_key_pem, + ) + + # Create OAuth provider for client_credentials with private_key_jwt + oauth_auth = PrivateKeyJWTOAuthProvider( + server_url=server_url, + storage=InMemoryTokenStorage(), + client_id=client_id, + assertion_provider=jwt_params.create_assertion_provider(), + ) + + await _run_session(server_url, oauth_auth) + + +async def run_client_credentials_basic_client(server_url: str) -> None: + """ + Run the conformance test client with client credentials flow using client_secret_basic. + + This function: + 1. Connects to the MCP server with OAuth client_credentials grant + 2. Uses client_secret_basic authentication with credentials from MCP_CONFORMANCE_CONTEXT + 3. Initializes the session + 4. Lists available tools + 5. Calls a test tool + """ + logger.debug(f"Starting conformance auth client (client_credentials_basic) for {server_url}") + + # Load credentials from environment + context = get_conformance_context() + client_id = context.get("client_id") + client_secret = context.get("client_secret") + + if not client_id: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_id'") + if not client_secret: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_secret'") + + # Create OAuth provider for client_credentials with client_secret_basic + oauth_auth = ClientCredentialsOAuthProvider( + server_url=server_url, + storage=InMemoryTokenStorage(), + client_id=client_id, + client_secret=client_secret, + token_endpoint_auth_method="client_secret_basic", + ) + + await _run_session(server_url, oauth_auth) + + +async def _run_session(server_url: str, oauth_auth: OAuthClientProvider) -> None: + """Common session logic for all OAuth flows.""" + # Connect using streamable HTTP transport with OAuth + async with streamablehttp_client( + url=server_url, + auth=oauth_auth, + timeout=30.0, + sse_read_timeout=60.0, + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + await session.initialize() + logger.debug("Successfully connected and initialized MCP session") + + # List tools + tools_result = await session.list_tools() + logger.debug(f"Listed tools: {[t.name for t in tools_result.tools]}") + + # Call test tool (expected by conformance tests) + try: + result = await session.call_tool("test-tool", {}) + logger.debug(f"Called test-tool, result: {result}") + except Exception as e: + logger.debug(f"Tool call result/error: {e}") + + logger.debug("Connection closed successfully") + + +def main() -> None: + """Main entry point for the conformance auth client.""" + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + print("", file=sys.stderr) + print("Scenarios:", file=sys.stderr) + print(" auth/* - Authorization code flow (default)", file=sys.stderr) + print(" auth/client-credentials-jwt - Client credentials with JWT auth (SEP-1046)", file=sys.stderr) + print(" auth/client-credentials-basic - Client credentials with client_secret_basic", file=sys.stderr) + sys.exit(1) + + scenario = sys.argv[1] + server_url = sys.argv[2] + + try: + if scenario == "auth/client-credentials-jwt": + asyncio.run(run_client_credentials_jwt_client(server_url)) + elif scenario == "auth/client-credentials-basic": + asyncio.run(run_client_credentials_basic_client(server_url)) + else: + # Default to authorization code flow for all other auth/* scenarios + asyncio.run(run_authorization_code_client(server_url)) + except Exception: + logger.exception("Client failed") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__main__.py b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__main__.py new file mode 100644 index 0000000000..1b8f8acb09 --- /dev/null +++ b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__main__.py @@ -0,0 +1,6 @@ +"""Allow running the module with python -m.""" + +from . import main + +if __name__ == "__main__": + main() diff --git a/examples/clients/conformance-auth-client/pyproject.toml b/examples/clients/conformance-auth-client/pyproject.toml new file mode 100644 index 0000000000..3d03b4d4a1 --- /dev/null +++ b/examples/clients/conformance-auth-client/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-conformance-auth-client" +version = "0.1.0" +description = "OAuth conformance test client for MCP" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic" }] +keywords = ["mcp", "oauth", "client", "auth", "conformance", "testing"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["mcp", "httpx>=0.28.1"] + +[project.scripts] +mcp-conformance-auth-client = "mcp_conformance_auth_client:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_conformance_auth_client"] + +[tool.pyright] +include = ["mcp_conformance_auth_client"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 39c69501d1..0223b72394 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -11,15 +11,15 @@ import threading import time import webbrowser -from datetime import timedelta from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any from urllib.parse import parse_qs, urlparse +import httpx from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -150,9 +150,15 @@ def get_state(self): class SimpleAuthClient: """Simple MCP client with auth support.""" - def __init__(self, server_url: str, transport_type: str = "streamable-http"): + def __init__( + self, + server_url: str, + transport_type: str = "streamable-http", + client_metadata_url: str | None = None, + ): self.server_url = server_url self.transport_type = transport_type + self.client_metadata_url = client_metadata_url self.session: ClientSession | None = None async def connect(self): @@ -177,7 +183,6 @@ async def callback_handler() -> tuple[str, str | None]: "redirect_uris": ["http://localhost:3030/callback"], "grant_types": ["authorization_code", "refresh_token"], "response_types": ["code"], - "token_endpoint_auth_method": "client_secret_post", } async def _default_redirect_handler(authorization_url: str) -> None: @@ -186,12 +191,14 @@ async def _default_redirect_handler(authorization_url: str) -> None: webbrowser.open(authorization_url) # Create OAuth authentication handler using the new interface + # Use client_metadata_url to enable CIMD when the server supports it oauth_auth = OAuthClientProvider( server_url=self.server_url.replace("/mcp", ""), client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), storage=InMemoryTokenStorage(), redirect_handler=_default_redirect_handler, callback_handler=callback_handler, + client_metadata_url=self.client_metadata_url, ) # Create transport with auth handler based on transport type @@ -200,17 +207,17 @@ async def _default_redirect_handler(authorization_url: str) -> None: async with sse_client( url=self.server_url, auth=oauth_auth, - timeout=60, + timeout=60.0, ) as (read_stream, write_stream): await self._run_session(read_stream, write_stream, None) else: print("📡 Opening StreamableHTTP transport connection with auth...") - async with streamablehttp_client( - url=self.server_url, - auth=oauth_auth, - timeout=timedelta(seconds=60), - ) as (read_stream, write_stream, get_session_id): - await self._run_session(read_stream, write_stream, get_session_id) + async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client: + async with streamable_http_client( + url=self.server_url, + http_client=custom_client, + ) as (read_stream, write_stream, get_session_id): + await self._run_session(read_stream, write_stream, get_session_id) except Exception as e: print(f"❌ Failed to connect: {e}") @@ -335,6 +342,7 @@ async def main(): # Most MCP streamable HTTP servers use /mcp as the endpoint server_url = os.getenv("MCP_SERVER_PORT", 8000) transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable-http") + client_metadata_url = os.getenv("MCP_CLIENT_METADATA_URL") server_url = ( f"http://localhost:{server_url}/mcp" if transport_type == "streamable-http" @@ -344,9 +352,11 @@ async def main(): print("🚀 Simple MCP Auth Client") print(f"Connecting to: {server_url}") print(f"Transport type: {transport_type}") + if client_metadata_url: + print(f"Client metadata URL: {client_metadata_url}") # Start connection flow - OAuth will be handled automatically - client = SimpleAuthClient(server_url, transport_type) + client = SimpleAuthClient(server_url, transport_type, client_metadata_url) await client.connect() diff --git a/examples/clients/simple-auth-client/pyproject.toml b/examples/clients/simple-auth-client/pyproject.toml index 0c1021072c..46aba8dc12 100644 --- a/examples/clients/simple-auth-client/pyproject.toml +++ b/examples/clients/simple-auth-client/pyproject.toml @@ -14,10 +14,7 @@ classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", ] -dependencies = [ - "click>=8.2.0", - "mcp>=1.0.0", -] +dependencies = ["click>=8.2.0", "mcp"] [project.scripts] mcp-simple-auth-client = "mcp_simple_auth_client.main:cli" @@ -42,11 +39,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] - -[tool.uv.sources] -mcp = { path = "../../../" } - -[[tool.uv.index]] -url = "https://pypi.org/simple" +[dependency-groups] +dev = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/clients/simple-auth-client/uv.lock b/examples/clients/simple-auth-client/uv.lock deleted file mode 100644 index a62447fcbe..0000000000 --- a/examples/clients/simple-auth-client/uv.lock +++ /dev/null @@ -1,535 +0,0 @@ -version = 1 -requires-python = ">=3.10" - -[[package]] -name = "annotated-types" -version = "0.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, -] - -[[package]] -name = "anyio" -version = "4.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "idna" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 }, -] - -[[package]] -name = "certifi" -version = "2025.4.26" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/9e/c05b3920a3b7d20d3d3310465f50348e5b3694f4f88c6daf736eef3024c4/certifi-2025.4.26.tar.gz", hash = "sha256:0a816057ea3cdefcef70270d2c515e4506bbc954f417fa5ade2021213bb8f0c6", size = 160705 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4a/7e/3db2bd1b1f9e95f7cddca6d6e75e2f2bd9f51b1246e546d88addca0106bd/certifi-2025.4.26-py3-none-any.whl", hash = "sha256:30350364dfe371162649852c63336a15c70c6510c2ad5015b21c2345311805f3", size = 159618 }, -] - -[[package]] -name = "click" -version = "8.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/cd/0f/62ca20172d4f87d93cf89665fbaedcd560ac48b465bd1d92bfc7ea6b0a41/click-8.2.0.tar.gz", hash = "sha256:f5452aeddd9988eefa20f90f05ab66f17fce1ee2a36907fd30b05bbb5953814d", size = 235857 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/58/1f37bf81e3c689cc74ffa42102fa8915b59085f54a6e4a80bc6265c0f6bf/click-8.2.0-py3-none-any.whl", hash = "sha256:6b303f0b2aa85f1cb4e5303078fadcbcd4e476f114fab9b5007005711839325c", size = 102156 }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, -] - -[[package]] -name = "exceptiongroup" -version = "1.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674 }, -] - -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515 }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784 }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, -] - -[[package]] -name = "httpx-sse" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819 }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, -] - -[[package]] -name = "iniconfig" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 }, -] - -[[package]] -name = "mcp" -source = { directory = "../../../" } -dependencies = [ - { name = "anyio" }, - { name = "httpx" }, - { name = "httpx-sse" }, - { name = "pydantic" }, - { name = "pydantic-settings" }, - { name = "python-multipart" }, - { name = "sse-starlette" }, - { name = "starlette" }, - { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, -] - -[package.metadata] -requires-dist = [ - { name = "anyio", specifier = ">=4.5" }, - { name = "httpx", specifier = ">=0.27" }, - { name = "httpx-sse", specifier = ">=0.4" }, - { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, - { name = "pydantic-settings", specifier = ">=2.5.2" }, - { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, - { name = "python-multipart", specifier = ">=0.0.9" }, - { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, - { name = "sse-starlette", specifier = ">=1.6.1" }, - { name = "starlette", specifier = ">=0.27" }, - { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, - { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, - { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.391" }, - { name = "pytest", specifier = ">=8.3.4" }, - { name = "pytest-examples", specifier = ">=0.0.14" }, - { name = "pytest-flakefinder", specifier = ">=1.1.0" }, - { name = "pytest-pretty", specifier = ">=1.2.0" }, - { name = "pytest-xdist", specifier = ">=3.6.1" }, - { name = "ruff", specifier = ">=0.8.5" }, - { name = "trio", specifier = ">=0.26.2" }, -] -docs = [ - { name = "mkdocs", specifier = ">=1.6.1" }, - { name = "mkdocs-glightbox", specifier = ">=0.4.0" }, - { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, - { name = "mkdocstrings-python", specifier = ">=1.12.2" }, -] - -[[package]] -name = "mcp-simple-auth-client" -version = "0.1.0" -source = { editable = "." } -dependencies = [ - { name = "click" }, - { name = "mcp" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "pytest" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "click", specifier = ">=8.0.0" }, - { name = "mcp", directory = "../../../" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.379" }, - { name = "pytest", specifier = ">=8.3.3" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "nodeenv" -version = "1.9.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, -] - -[[package]] -name = "packaging" -version = "25.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469 }, -] - -[[package]] -name = "pluggy" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 }, -] - -[[package]] -name = "pydantic" -version = "2.11.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "annotated-types" }, - { name = "pydantic-core" }, - { name = "typing-extensions" }, - { name = "typing-inspection" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 }, -] - -[[package]] -name = "pydantic-core" -version = "2.33.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/92/b31726561b5dae176c2d2c2dc43a9c5bfba5d32f96f8b4c0a600dd492447/pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8", size = 2028817 }, - { url = "https://files.pythonhosted.org/packages/a3/44/3f0b95fafdaca04a483c4e685fe437c6891001bf3ce8b2fded82b9ea3aa1/pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d", size = 1861357 }, - { url = "https://files.pythonhosted.org/packages/30/97/e8f13b55766234caae05372826e8e4b3b96e7b248be3157f53237682e43c/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d", size = 1898011 }, - { url = "https://files.pythonhosted.org/packages/9b/a3/99c48cf7bafc991cc3ee66fd544c0aae8dc907b752f1dad2d79b1b5a471f/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572", size = 1982730 }, - { url = "https://files.pythonhosted.org/packages/de/8e/a5b882ec4307010a840fb8b58bd9bf65d1840c92eae7534c7441709bf54b/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02", size = 2136178 }, - { url = "https://files.pythonhosted.org/packages/e4/bb/71e35fc3ed05af6834e890edb75968e2802fe98778971ab5cba20a162315/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b", size = 2736462 }, - { url = "https://files.pythonhosted.org/packages/31/0d/c8f7593e6bc7066289bbc366f2235701dcbebcd1ff0ef8e64f6f239fb47d/pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2", size = 2005652 }, - { url = "https://files.pythonhosted.org/packages/d2/7a/996d8bd75f3eda405e3dd219ff5ff0a283cd8e34add39d8ef9157e722867/pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a", size = 2113306 }, - { url = "https://files.pythonhosted.org/packages/ff/84/daf2a6fb2db40ffda6578a7e8c5a6e9c8affb251a05c233ae37098118788/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac", size = 2073720 }, - { url = "https://files.pythonhosted.org/packages/77/fb/2258da019f4825128445ae79456a5499c032b55849dbd5bed78c95ccf163/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a", size = 2244915 }, - { url = "https://files.pythonhosted.org/packages/d8/7a/925ff73756031289468326e355b6fa8316960d0d65f8b5d6b3a3e7866de7/pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b", size = 2241884 }, - { url = "https://files.pythonhosted.org/packages/0b/b0/249ee6d2646f1cdadcb813805fe76265745c4010cf20a8eba7b0e639d9b2/pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = "sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22", size = 1910496 }, - { url = "https://files.pythonhosted.org/packages/66/ff/172ba8f12a42d4b552917aa65d1f2328990d3ccfc01d5b7c943ec084299f/pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640", size = 1955019 }, - { url = "https://files.pythonhosted.org/packages/3f/8d/71db63483d518cbbf290261a1fc2839d17ff89fce7089e08cad07ccfce67/pydantic_core-2.33.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4c5b0a576fb381edd6d27f0a85915c6daf2f8138dc5c267a57c08a62900758c7", size = 2028584 }, - { url = "https://files.pythonhosted.org/packages/24/2f/3cfa7244ae292dd850989f328722d2aef313f74ffc471184dc509e1e4e5a/pydantic_core-2.33.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e799c050df38a639db758c617ec771fd8fb7a5f8eaaa4b27b101f266b216a246", size = 1855071 }, - { url = "https://files.pythonhosted.org/packages/b3/d3/4ae42d33f5e3f50dd467761304be2fa0a9417fbf09735bc2cce003480f2a/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc46a01bf8d62f227d5ecee74178ffc448ff4e5197c756331f71efcc66dc980f", size = 1897823 }, - { url = "https://files.pythonhosted.org/packages/f4/f3/aa5976e8352b7695ff808599794b1fba2a9ae2ee954a3426855935799488/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a144d4f717285c6d9234a66778059f33a89096dfb9b39117663fd8413d582dcc", size = 1983792 }, - { url = "https://files.pythonhosted.org/packages/d5/7a/cda9b5a23c552037717f2b2a5257e9b2bfe45e687386df9591eff7b46d28/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cf6373c21bc80b2e0dc88444f41ae60b2f070ed02095754eb5a01df12256de", size = 2136338 }, - { url = "https://files.pythonhosted.org/packages/2b/9f/b8f9ec8dd1417eb9da784e91e1667d58a2a4a7b7b34cf4af765ef663a7e5/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dc625f4aa79713512d1976fe9f0bc99f706a9dee21dfd1810b4bbbf228d0e8a", size = 2730998 }, - { url = "https://files.pythonhosted.org/packages/47/bc/cd720e078576bdb8255d5032c5d63ee5c0bf4b7173dd955185a1d658c456/pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b21b5549499972441da4758d662aeea93f1923f953e9cbaff14b8b9565aef", size = 2003200 }, - { url = "https://files.pythonhosted.org/packages/ca/22/3602b895ee2cd29d11a2b349372446ae9727c32e78a94b3d588a40fdf187/pydantic_core-2.33.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bdc25f3681f7b78572699569514036afe3c243bc3059d3942624e936ec93450e", size = 2113890 }, - { url = "https://files.pythonhosted.org/packages/ff/e6/e3c5908c03cf00d629eb38393a98fccc38ee0ce8ecce32f69fc7d7b558a7/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fe5b32187cbc0c862ee201ad66c30cf218e5ed468ec8dc1cf49dec66e160cc4d", size = 2073359 }, - { url = "https://files.pythonhosted.org/packages/12/e7/6a36a07c59ebefc8777d1ffdaf5ae71b06b21952582e4b07eba88a421c79/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:bc7aee6f634a6f4a95676fcb5d6559a2c2a390330098dba5e5a5f28a2e4ada30", size = 2245883 }, - { url = "https://files.pythonhosted.org/packages/16/3f/59b3187aaa6cc0c1e6616e8045b284de2b6a87b027cce2ffcea073adf1d2/pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:235f45e5dbcccf6bd99f9f472858849f73d11120d76ea8707115415f8e5ebebf", size = 2241074 }, - { url = "https://files.pythonhosted.org/packages/e0/ed/55532bb88f674d5d8f67ab121a2a13c385df382de2a1677f30ad385f7438/pydantic_core-2.33.2-cp311-cp311-win32.whl", hash = "sha256:6368900c2d3ef09b69cb0b913f9f8263b03786e5b2a387706c5afb66800efd51", size = 1910538 }, - { url = "https://files.pythonhosted.org/packages/fe/1b/25b7cccd4519c0b23c2dd636ad39d381abf113085ce4f7bec2b0dc755eb1/pydantic_core-2.33.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e063337ef9e9820c77acc768546325ebe04ee38b08703244c1309cccc4f1bab", size = 1952909 }, - { url = "https://files.pythonhosted.org/packages/49/a9/d809358e49126438055884c4366a1f6227f0f84f635a9014e2deb9b9de54/pydantic_core-2.33.2-cp311-cp311-win_arm64.whl", hash = "sha256:6b99022f1d19bc32a4c2a0d544fc9a76e3be90f0b3f4af413f87d38749300e65", size = 1897786 }, - { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000 }, - { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996 }, - { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957 }, - { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199 }, - { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296 }, - { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109 }, - { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028 }, - { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044 }, - { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881 }, - { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034 }, - { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187 }, - { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628 }, - { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866 }, - { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894 }, - { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688 }, - { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808 }, - { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580 }, - { url = "https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859 }, - { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810 }, - { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498 }, - { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611 }, - { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924 }, - { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196 }, - { url = "https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389 }, - { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223 }, - { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473 }, - { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269 }, - { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921 }, - { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162 }, - { url = "https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560 }, - { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777 }, - { url = "https://files.pythonhosted.org/packages/30/68/373d55e58b7e83ce371691f6eaa7175e3a24b956c44628eb25d7da007917/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa", size = 2023982 }, - { url = "https://files.pythonhosted.org/packages/a4/16/145f54ac08c96a63d8ed6442f9dec17b2773d19920b627b18d4f10a061ea/pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29", size = 1858412 }, - { url = "https://files.pythonhosted.org/packages/41/b1/c6dc6c3e2de4516c0bb2c46f6a373b91b5660312342a0cf5826e38ad82fa/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d", size = 1892749 }, - { url = "https://files.pythonhosted.org/packages/12/73/8cd57e20afba760b21b742106f9dbdfa6697f1570b189c7457a1af4cd8a0/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e", size = 2067527 }, - { url = "https://files.pythonhosted.org/packages/e3/d5/0bb5d988cc019b3cba4a78f2d4b3854427fc47ee8ec8e9eaabf787da239c/pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c", size = 2108225 }, - { url = "https://files.pythonhosted.org/packages/f1/c5/00c02d1571913d496aabf146106ad8239dc132485ee22efe08085084ff7c/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec", size = 2069490 }, - { url = "https://files.pythonhosted.org/packages/22/a8/dccc38768274d3ed3a59b5d06f59ccb845778687652daa71df0cab4040d7/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052", size = 2237525 }, - { url = "https://files.pythonhosted.org/packages/d4/e7/4f98c0b125dda7cf7ccd14ba936218397b44f50a56dd8c16a3091df116c3/pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c", size = 2238446 }, - { url = "https://files.pythonhosted.org/packages/ce/91/2ec36480fdb0b783cd9ef6795753c1dea13882f2e68e73bce76ae8c21e6a/pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808", size = 2066678 }, - { url = "https://files.pythonhosted.org/packages/7b/27/d4ae6487d73948d6f20dddcd94be4ea43e74349b56eba82e9bdee2d7494c/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:dd14041875d09cc0f9308e37a6f8b65f5585cf2598a53aa0123df8b129d481f8", size = 2025200 }, - { url = "https://files.pythonhosted.org/packages/f1/b8/b3cb95375f05d33801024079b9392a5ab45267a63400bf1866e7ce0f0de4/pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d87c561733f66531dced0da6e864f44ebf89a8fba55f31407b00c2f7f9449593", size = 1859123 }, - { url = "https://files.pythonhosted.org/packages/05/bc/0d0b5adeda59a261cd30a1235a445bf55c7e46ae44aea28f7bd6ed46e091/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f82865531efd18d6e07a04a17331af02cb7a651583c418df8266f17a63c6612", size = 1892852 }, - { url = "https://files.pythonhosted.org/packages/3e/11/d37bdebbda2e449cb3f519f6ce950927b56d62f0b84fd9cb9e372a26a3d5/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bfb5112df54209d820d7bf9317c7a6c9025ea52e49f46b6a2060104bba37de7", size = 2067484 }, - { url = "https://files.pythonhosted.org/packages/8c/55/1f95f0a05ce72ecb02a8a8a1c3be0579bbc29b1d5ab68f1378b7bebc5057/pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64632ff9d614e5eecfb495796ad51b0ed98c453e447a76bcbeeb69615079fc7e", size = 2108896 }, - { url = "https://files.pythonhosted.org/packages/53/89/2b2de6c81fa131f423246a9109d7b2a375e83968ad0800d6e57d0574629b/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f889f7a40498cc077332c7ab6b4608d296d852182211787d4f3ee377aaae66e8", size = 2069475 }, - { url = "https://files.pythonhosted.org/packages/b8/e9/1f7efbe20d0b2b10f6718944b5d8ece9152390904f29a78e68d4e7961159/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:de4b83bb311557e439b9e186f733f6c645b9417c84e2eb8203f3f820a4b988bf", size = 2239013 }, - { url = "https://files.pythonhosted.org/packages/3c/b2/5309c905a93811524a49b4e031e9851a6b00ff0fb668794472ea7746b448/pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82f68293f055f51b51ea42fafc74b6aad03e70e191799430b90c13d643059ebb", size = 2238715 }, - { url = "https://files.pythonhosted.org/packages/32/56/8a7ca5d2cd2cda1d245d34b1c9a942920a718082ae8e54e5f3e5a58b7add/pydantic_core-2.33.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:329467cecfb529c925cf2bbd4d60d2c509bc2fb52a20c1045bf09bb70971a9c1", size = 2066757 }, -] - -[[package]] -name = "pydantic-settings" -version = "2.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pydantic" }, - { name = "python-dotenv" }, - { name = "typing-inspection" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/67/1d/42628a2c33e93f8e9acbde0d5d735fa0850f3e6a2f8cb1eb6c40b9a732ac/pydantic_settings-2.9.1.tar.gz", hash = "sha256:c509bf79d27563add44e8446233359004ed85066cd096d8b510f715e6ef5d268", size = 163234 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b6/5f/d6d641b490fd3ec2c4c13b4244d68deea3a1b970a97be64f34fb5504ff72/pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef", size = 44356 }, -] - -[[package]] -name = "pyright" -version = "1.1.400" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nodeenv" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/6c/cb/c306618a02d0ee8aed5fb8d0fe0ecfed0dbf075f71468f03a30b5f4e1fe0/pyright-1.1.400.tar.gz", hash = "sha256:b8a3ba40481aa47ba08ffb3228e821d22f7d391f83609211335858bf05686bdb", size = 3846546 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/a5/5d285e4932cf149c90e3c425610c5efaea005475d5f96f1bfdb452956c62/pyright-1.1.400-py3-none-any.whl", hash = "sha256:c80d04f98b5a4358ad3a35e241dbf2a408eee33a40779df365644f8054d2517e", size = 5563460 }, -] - -[[package]] -name = "pytest" -version = "8.3.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 }, -] - -[[package]] -name = "python-dotenv" -version = "1.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/88/2c/7bb1416c5620485aa793f2de31d3df393d3686aa8a8506d11e10e13c5baf/python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5", size = 39920 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256 }, -] - -[[package]] -name = "python-multipart" -version = "0.0.20" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546 }, -] - -[[package]] -name = "ruff" -version = "0.11.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/4c/4a3c5a97faaae6b428b336dcca81d03ad04779f8072c267ad2bd860126bf/ruff-0.11.10.tar.gz", hash = "sha256:d522fb204b4959909ecac47da02830daec102eeb100fb50ea9554818d47a5fa6", size = 4165632 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/9f/596c628f8824a2ce4cd12b0f0b4c0629a62dfffc5d0f742c19a1d71be108/ruff-0.11.10-py3-none-linux_armv6l.whl", hash = "sha256:859a7bfa7bc8888abbea31ef8a2b411714e6a80f0d173c2a82f9041ed6b50f58", size = 10316243 }, - { url = "https://files.pythonhosted.org/packages/3c/38/c1e0b77ab58b426f8c332c1d1d3432d9fc9a9ea622806e208220cb133c9e/ruff-0.11.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:968220a57e09ea5e4fd48ed1c646419961a0570727c7e069842edd018ee8afed", size = 11083636 }, - { url = "https://files.pythonhosted.org/packages/23/41/b75e15961d6047d7fe1b13886e56e8413be8467a4e1be0a07f3b303cd65a/ruff-0.11.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1067245bad978e7aa7b22f67113ecc6eb241dca0d9b696144256c3a879663bca", size = 10441624 }, - { url = "https://files.pythonhosted.org/packages/b6/2c/e396b6703f131406db1811ea3d746f29d91b41bbd43ad572fea30da1435d/ruff-0.11.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4854fd09c7aed5b1590e996a81aeff0c9ff51378b084eb5a0b9cd9518e6cff2", size = 10624358 }, - { url = "https://files.pythonhosted.org/packages/bd/8c/ee6cca8bdaf0f9a3704796022851a33cd37d1340bceaf4f6e991eb164e2e/ruff-0.11.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b4564e9f99168c0f9195a0fd5fa5928004b33b377137f978055e40008a082c5", size = 10176850 }, - { url = "https://files.pythonhosted.org/packages/e9/ce/4e27e131a434321b3b7c66512c3ee7505b446eb1c8a80777c023f7e876e6/ruff-0.11.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b6a9cc5b62c03cc1fea0044ed8576379dbaf751d5503d718c973d5418483641", size = 11759787 }, - { url = "https://files.pythonhosted.org/packages/58/de/1e2e77fc72adc7cf5b5123fd04a59ed329651d3eab9825674a9e640b100b/ruff-0.11.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:607ecbb6f03e44c9e0a93aedacb17b4eb4f3563d00e8b474298a201622677947", size = 12430479 }, - { url = "https://files.pythonhosted.org/packages/07/ed/af0f2340f33b70d50121628ef175523cc4c37619e98d98748c85764c8d88/ruff-0.11.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b3a522fa389402cd2137df9ddefe848f727250535c70dafa840badffb56b7a4", size = 11919760 }, - { url = "https://files.pythonhosted.org/packages/24/09/d7b3d3226d535cb89234390f418d10e00a157b6c4a06dfbe723e9322cb7d/ruff-0.11.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f071b0deed7e9245d5820dac235cbdd4ef99d7b12ff04c330a241ad3534319f", size = 14041747 }, - { url = "https://files.pythonhosted.org/packages/62/b3/a63b4e91850e3f47f78795e6630ee9266cb6963de8f0191600289c2bb8f4/ruff-0.11.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a60e3a0a617eafba1f2e4186d827759d65348fa53708ca547e384db28406a0b", size = 11550657 }, - { url = "https://files.pythonhosted.org/packages/46/63/a4f95c241d79402ccdbdb1d823d156c89fbb36ebfc4289dce092e6c0aa8f/ruff-0.11.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:da8ec977eaa4b7bf75470fb575bea2cb41a0e07c7ea9d5a0a97d13dbca697bf2", size = 10489671 }, - { url = "https://files.pythonhosted.org/packages/6a/9b/c2238bfebf1e473495659c523d50b1685258b6345d5ab0b418ca3f010cd7/ruff-0.11.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ddf8967e08227d1bd95cc0851ef80d2ad9c7c0c5aab1eba31db49cf0a7b99523", size = 10160135 }, - { url = "https://files.pythonhosted.org/packages/ba/ef/ba7251dd15206688dbfba7d413c0312e94df3b31b08f5d695580b755a899/ruff-0.11.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5a94acf798a82db188f6f36575d80609072b032105d114b0f98661e1679c9125", size = 11170179 }, - { url = "https://files.pythonhosted.org/packages/73/9f/5c336717293203ba275dbfa2ea16e49b29a9fd9a0ea8b6febfc17e133577/ruff-0.11.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3afead355f1d16d95630df28d4ba17fb2cb9c8dfac8d21ced14984121f639bad", size = 11626021 }, - { url = "https://files.pythonhosted.org/packages/d9/2b/162fa86d2639076667c9aa59196c020dc6d7023ac8f342416c2f5ec4bda0/ruff-0.11.10-py3-none-win32.whl", hash = "sha256:dc061a98d32a97211af7e7f3fa1d4ca2fcf919fb96c28f39551f35fc55bdbc19", size = 10494958 }, - { url = "https://files.pythonhosted.org/packages/24/f3/66643d8f32f50a4b0d09a4832b7d919145ee2b944d43e604fbd7c144d175/ruff-0.11.10-py3-none-win_amd64.whl", hash = "sha256:5cc725fbb4d25b0f185cb42df07ab6b76c4489b4bfb740a175f3a59c70e8a224", size = 11650285 }, - { url = "https://files.pythonhosted.org/packages/95/3a/2e8704d19f376c799748ff9cb041225c1d59f3e7711bc5596c8cfdc24925/ruff-0.11.10-py3-none-win_arm64.whl", hash = "sha256:ef69637b35fb8b210743926778d0e45e1bffa850a7c61e428c6b971549b5f5d1", size = 10765278 }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, -] - -[[package]] -name = "sse-starlette" -version = "2.3.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "starlette" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/10/5f/28f45b1ff14bee871bacafd0a97213f7ec70e389939a80c60c0fb72a9fc9/sse_starlette-2.3.5.tar.gz", hash = "sha256:228357b6e42dcc73a427990e2b4a03c023e2495ecee82e14f07ba15077e334b2", size = 17511 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/48/3e49cf0f64961656402c0023edbc51844fe17afe53ab50e958a6dbbbd499/sse_starlette-2.3.5-py3-none-any.whl", hash = "sha256:251708539a335570f10eaaa21d1848a10c42ee6dc3a9cf37ef42266cdb1c52a8", size = 10233 }, -] - -[[package]] -name = "starlette" -version = "0.46.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, -] - -[[package]] -name = "tomli" -version = "2.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 }, - { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 }, - { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 }, - { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 }, - { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 }, - { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 }, - { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 }, - { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 }, - { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 }, - { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 }, - { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762 }, - { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453 }, - { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486 }, - { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349 }, - { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159 }, - { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243 }, - { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645 }, - { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584 }, - { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875 }, - { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418 }, - { url = "https://files.pythonhosted.org/packages/04/90/2ee5f2e0362cb8a0b6499dc44f4d7d48f8fff06d28ba46e6f1eaa61a1388/tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7", size = 132708 }, - { url = "https://files.pythonhosted.org/packages/c0/ec/46b4108816de6b385141f082ba99e315501ccd0a2ea23db4a100dd3990ea/tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c", size = 123582 }, - { url = "https://files.pythonhosted.org/packages/a0/bd/b470466d0137b37b68d24556c38a0cc819e8febe392d5b199dcd7f578365/tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13", size = 232543 }, - { url = "https://files.pythonhosted.org/packages/d9/e5/82e80ff3b751373f7cead2815bcbe2d51c895b3c990686741a8e56ec42ab/tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281", size = 241691 }, - { url = "https://files.pythonhosted.org/packages/05/7e/2a110bc2713557d6a1bfb06af23dd01e7dde52b6ee7dadc589868f9abfac/tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272", size = 251170 }, - { url = "https://files.pythonhosted.org/packages/64/7b/22d713946efe00e0adbcdfd6d1aa119ae03fd0b60ebed51ebb3fa9f5a2e5/tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140", size = 236530 }, - { url = "https://files.pythonhosted.org/packages/38/31/3a76f67da4b0cf37b742ca76beaf819dca0ebef26d78fc794a576e08accf/tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2", size = 258666 }, - { url = "https://files.pythonhosted.org/packages/07/10/5af1293da642aded87e8a988753945d0cf7e00a9452d3911dd3bb354c9e2/tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744", size = 243954 }, - { url = "https://files.pythonhosted.org/packages/5b/b9/1ed31d167be802da0fc95020d04cd27b7d7065cc6fbefdd2f9186f60d7bd/tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec", size = 98724 }, - { url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383 }, - { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, -] - -[[package]] -name = "typing-extensions" -version = "4.13.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f6/37/23083fcd6e35492953e8d2aaaa68b860eb422b34627b13f2ce3eb6106061/typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef", size = 106967 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806 }, -] - -[[package]] -name = "typing-inspection" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122", size = 76222 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, -] - -[[package]] -name = "uvicorn" -version = "0.34.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "h11" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a6/ae/9bbb19b9e1c450cf9ecaef06463e40234d98d95bf572fab11b4f19ae5ded/uvicorn-0.34.2.tar.gz", hash = "sha256:0e929828f6186353a80b58ea719861d2629d766293b6d19baf086ba31d4f3328", size = 76815 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/4b/4cef6ce21a2aaca9d852a6e84ef4f135d99fcd74fa75105e2fc0c8308acd/uvicorn-0.34.2-py3-none-any.whl", hash = "sha256:deb49af569084536d269fe0a6d67e3754f104cf03aba7c11c40f01aadf33c403", size = 62483 }, -] diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 1a30578b64..78a81a4d9f 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -401,7 +401,7 @@ async def start(self) -> None: await self.cleanup_servers() -async def main() -> None: +async def run() -> None: """Initialize and run the chat session.""" config = Configuration() server_config = config.load_config("servers_config.json") @@ -411,5 +411,9 @@ async def main() -> None: await chat_session.start() +def main() -> None: + asyncio.run(run()) + + if __name__ == "__main__": - asyncio.run(main()) + main() diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/requirements.txt b/examples/clients/simple-chatbot/mcp_simple_chatbot/requirements.txt index c01e1576c2..2292072ffa 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/requirements.txt +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/requirements.txt @@ -1,4 +1,4 @@ python-dotenv>=1.0.0 requests>=2.31.0 mcp>=1.0.0 -uvicorn>=0.32.1 \ No newline at end of file +uvicorn>=0.32.1 diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/servers_config.json b/examples/clients/simple-chatbot/mcp_simple_chatbot/servers_config.json index 98f8e1fd56..3a92d05d1e 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/servers_config.json +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/servers_config.json @@ -9,4 +9,4 @@ "args": ["-y", "@modelcontextprotocol/server-puppeteer"] } } -} \ No newline at end of file +} diff --git a/examples/clients/simple-chatbot/pyproject.toml b/examples/clients/simple-chatbot/pyproject.toml index b699ecc32a..564b42df33 100644 --- a/examples/clients/simple-chatbot/pyproject.toml +++ b/examples/clients/simple-chatbot/pyproject.toml @@ -17,12 +17,12 @@ classifiers = [ dependencies = [ "python-dotenv>=1.0.0", "requests>=2.31.0", - "mcp>=1.0.0", - "uvicorn>=0.32.1" + "mcp", + "uvicorn>=0.32.1", ] [project.scripts] -mcp-simple-chatbot = "mcp_simple_chatbot.client:main" +mcp-simple-chatbot = "mcp_simple_chatbot.main:main" [build-system] requires = ["hatchling"] @@ -44,5 +44,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] +[dependency-groups] +dev = ["pyright>=1.1.379", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/clients/simple-chatbot/uv.lock b/examples/clients/simple-chatbot/uv.lock deleted file mode 100644 index ee7cb2fab7..0000000000 --- a/examples/clients/simple-chatbot/uv.lock +++ /dev/null @@ -1,555 +0,0 @@ -version = 1 -requires-python = ">=3.10" - -[[package]] -name = "annotated-types" -version = "0.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, -] - -[[package]] -name = "anyio" -version = "4.8.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "idna" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a3/73/199a98fc2dae33535d6b8e8e6ec01f8c1d76c9adb096c6b7d64823038cde/anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a", size = 181126 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 }, -] - -[[package]] -name = "certifi" -version = "2024.12.14" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0f/bd/1d41ee578ce09523c81a15426705dd20969f5abf006d1afe8aeff0dd776a/certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db", size = 166010 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/32/8f6669fc4798494966bf446c8c4a162e0b5d893dff088afddf76414f70e1/certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56", size = 164927 }, -] - -[[package]] -name = "charset-normalizer" -version = "3.4.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/16/b0/572805e227f01586461c80e0fd25d65a2115599cc9dad142fee4b747c357/charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3", size = 123188 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0d/58/5580c1716040bc89206c77d8f74418caf82ce519aae06450393ca73475d1/charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de", size = 198013 }, - { url = "https://files.pythonhosted.org/packages/d0/11/00341177ae71c6f5159a08168bcb98c6e6d196d372c94511f9f6c9afe0c6/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176", size = 141285 }, - { url = "https://files.pythonhosted.org/packages/01/09/11d684ea5819e5a8f5100fb0b38cf8d02b514746607934134d31233e02c8/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037", size = 151449 }, - { url = "https://files.pythonhosted.org/packages/08/06/9f5a12939db324d905dc1f70591ae7d7898d030d7662f0d426e2286f68c9/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f", size = 143892 }, - { url = "https://files.pythonhosted.org/packages/93/62/5e89cdfe04584cb7f4d36003ffa2936681b03ecc0754f8e969c2becb7e24/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a", size = 146123 }, - { url = "https://files.pythonhosted.org/packages/a9/ac/ab729a15c516da2ab70a05f8722ecfccc3f04ed7a18e45c75bbbaa347d61/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a", size = 147943 }, - { url = "https://files.pythonhosted.org/packages/03/d2/3f392f23f042615689456e9a274640c1d2e5dd1d52de36ab8f7955f8f050/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247", size = 142063 }, - { url = "https://files.pythonhosted.org/packages/f2/e3/e20aae5e1039a2cd9b08d9205f52142329f887f8cf70da3650326670bddf/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408", size = 150578 }, - { url = "https://files.pythonhosted.org/packages/8d/af/779ad72a4da0aed925e1139d458adc486e61076d7ecdcc09e610ea8678db/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb", size = 153629 }, - { url = "https://files.pythonhosted.org/packages/c2/b6/7aa450b278e7aa92cf7732140bfd8be21f5f29d5bf334ae987c945276639/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d", size = 150778 }, - { url = "https://files.pythonhosted.org/packages/39/f4/d9f4f712d0951dcbfd42920d3db81b00dd23b6ab520419626f4023334056/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807", size = 146453 }, - { url = "https://files.pythonhosted.org/packages/49/2b/999d0314e4ee0cff3cb83e6bc9aeddd397eeed693edb4facb901eb8fbb69/charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f", size = 95479 }, - { url = "https://files.pythonhosted.org/packages/2d/ce/3cbed41cff67e455a386fb5e5dd8906cdda2ed92fbc6297921f2e4419309/charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f", size = 102790 }, - { url = "https://files.pythonhosted.org/packages/72/80/41ef5d5a7935d2d3a773e3eaebf0a9350542f2cab4eac59a7a4741fbbbbe/charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125", size = 194995 }, - { url = "https://files.pythonhosted.org/packages/7a/28/0b9fefa7b8b080ec492110af6d88aa3dea91c464b17d53474b6e9ba5d2c5/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1", size = 139471 }, - { url = "https://files.pythonhosted.org/packages/71/64/d24ab1a997efb06402e3fc07317e94da358e2585165930d9d59ad45fcae2/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3", size = 149831 }, - { url = "https://files.pythonhosted.org/packages/37/ed/be39e5258e198655240db5e19e0b11379163ad7070962d6b0c87ed2c4d39/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd", size = 142335 }, - { url = "https://files.pythonhosted.org/packages/88/83/489e9504711fa05d8dde1574996408026bdbdbd938f23be67deebb5eca92/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00", size = 143862 }, - { url = "https://files.pythonhosted.org/packages/c6/c7/32da20821cf387b759ad24627a9aca289d2822de929b8a41b6241767b461/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12", size = 145673 }, - { url = "https://files.pythonhosted.org/packages/68/85/f4288e96039abdd5aeb5c546fa20a37b50da71b5cf01e75e87f16cd43304/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77", size = 140211 }, - { url = "https://files.pythonhosted.org/packages/28/a3/a42e70d03cbdabc18997baf4f0227c73591a08041c149e710045c281f97b/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146", size = 148039 }, - { url = "https://files.pythonhosted.org/packages/85/e4/65699e8ab3014ecbe6f5c71d1a55d810fb716bbfd74f6283d5c2aa87febf/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd", size = 151939 }, - { url = "https://files.pythonhosted.org/packages/b1/82/8e9fe624cc5374193de6860aba3ea8070f584c8565ee77c168ec13274bd2/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6", size = 149075 }, - { url = "https://files.pythonhosted.org/packages/3d/7b/82865ba54c765560c8433f65e8acb9217cb839a9e32b42af4aa8e945870f/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8", size = 144340 }, - { url = "https://files.pythonhosted.org/packages/b5/b6/9674a4b7d4d99a0d2df9b215da766ee682718f88055751e1e5e753c82db0/charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b", size = 95205 }, - { url = "https://files.pythonhosted.org/packages/1e/ab/45b180e175de4402dcf7547e4fb617283bae54ce35c27930a6f35b6bef15/charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76", size = 102441 }, - { url = "https://files.pythonhosted.org/packages/0a/9a/dd1e1cdceb841925b7798369a09279bd1cf183cef0f9ddf15a3a6502ee45/charset_normalizer-3.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545", size = 196105 }, - { url = "https://files.pythonhosted.org/packages/d3/8c/90bfabf8c4809ecb648f39794cf2a84ff2e7d2a6cf159fe68d9a26160467/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7", size = 140404 }, - { url = "https://files.pythonhosted.org/packages/ad/8f/e410d57c721945ea3b4f1a04b74f70ce8fa800d393d72899f0a40526401f/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757", size = 150423 }, - { url = "https://files.pythonhosted.org/packages/f0/b8/e6825e25deb691ff98cf5c9072ee0605dc2acfca98af70c2d1b1bc75190d/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa", size = 143184 }, - { url = "https://files.pythonhosted.org/packages/3e/a2/513f6cbe752421f16d969e32f3583762bfd583848b763913ddab8d9bfd4f/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d", size = 145268 }, - { url = "https://files.pythonhosted.org/packages/74/94/8a5277664f27c3c438546f3eb53b33f5b19568eb7424736bdc440a88a31f/charset_normalizer-3.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616", size = 147601 }, - { url = "https://files.pythonhosted.org/packages/7c/5f/6d352c51ee763623a98e31194823518e09bfa48be2a7e8383cf691bbb3d0/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b", size = 141098 }, - { url = "https://files.pythonhosted.org/packages/78/d4/f5704cb629ba5ab16d1d3d741396aec6dc3ca2b67757c45b0599bb010478/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d", size = 149520 }, - { url = "https://files.pythonhosted.org/packages/c5/96/64120b1d02b81785f222b976c0fb79a35875457fa9bb40827678e54d1bc8/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a", size = 152852 }, - { url = "https://files.pythonhosted.org/packages/84/c9/98e3732278a99f47d487fd3468bc60b882920cef29d1fa6ca460a1fdf4e6/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9", size = 150488 }, - { url = "https://files.pythonhosted.org/packages/13/0e/9c8d4cb99c98c1007cc11eda969ebfe837bbbd0acdb4736d228ccaabcd22/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1", size = 146192 }, - { url = "https://files.pythonhosted.org/packages/b2/21/2b6b5b860781a0b49427309cb8670785aa543fb2178de875b87b9cc97746/charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35", size = 95550 }, - { url = "https://files.pythonhosted.org/packages/21/5b/1b390b03b1d16c7e382b561c5329f83cc06623916aab983e8ab9239c7d5c/charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f", size = 102785 }, - { url = "https://files.pythonhosted.org/packages/38/94/ce8e6f63d18049672c76d07d119304e1e2d7c6098f0841b51c666e9f44a0/charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda", size = 195698 }, - { url = "https://files.pythonhosted.org/packages/24/2e/dfdd9770664aae179a96561cc6952ff08f9a8cd09a908f259a9dfa063568/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313", size = 140162 }, - { url = "https://files.pythonhosted.org/packages/24/4e/f646b9093cff8fc86f2d60af2de4dc17c759de9d554f130b140ea4738ca6/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9", size = 150263 }, - { url = "https://files.pythonhosted.org/packages/5e/67/2937f8d548c3ef6e2f9aab0f6e21001056f692d43282b165e7c56023e6dd/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b", size = 142966 }, - { url = "https://files.pythonhosted.org/packages/52/ed/b7f4f07de100bdb95c1756d3a4d17b90c1a3c53715c1a476f8738058e0fa/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11", size = 144992 }, - { url = "https://files.pythonhosted.org/packages/96/2c/d49710a6dbcd3776265f4c923bb73ebe83933dfbaa841c5da850fe0fd20b/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f", size = 147162 }, - { url = "https://files.pythonhosted.org/packages/b4/41/35ff1f9a6bd380303dea55e44c4933b4cc3c4850988927d4082ada230273/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd", size = 140972 }, - { url = "https://files.pythonhosted.org/packages/fb/43/c6a0b685fe6910d08ba971f62cd9c3e862a85770395ba5d9cad4fede33ab/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2", size = 149095 }, - { url = "https://files.pythonhosted.org/packages/4c/ff/a9a504662452e2d2878512115638966e75633519ec11f25fca3d2049a94a/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886", size = 152668 }, - { url = "https://files.pythonhosted.org/packages/6c/71/189996b6d9a4b932564701628af5cee6716733e9165af1d5e1b285c530ed/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601", size = 150073 }, - { url = "https://files.pythonhosted.org/packages/e4/93/946a86ce20790e11312c87c75ba68d5f6ad2208cfb52b2d6a2c32840d922/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd", size = 145732 }, - { url = "https://files.pythonhosted.org/packages/cd/e5/131d2fb1b0dddafc37be4f3a2fa79aa4c037368be9423061dccadfd90091/charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407", size = 95391 }, - { url = "https://files.pythonhosted.org/packages/27/f2/4f9a69cc7712b9b5ad8fdb87039fd89abba997ad5cbe690d1835d40405b0/charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971", size = 102702 }, - { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, -] - -[[package]] -name = "click" -version = "8.1.8" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, -] - -[[package]] -name = "exceptiongroup" -version = "1.2.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, -] - -[[package]] -name = "h11" -version = "0.14.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, -] - -[[package]] -name = "httpcore" -version = "1.0.7" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, -] - -[[package]] -name = "httpx-sse" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819 }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, -] - -[[package]] -name = "iniconfig" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, -] - -[[package]] -name = "mcp" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "httpx" }, - { name = "httpx-sse" }, - { name = "pydantic" }, - { name = "pydantic-settings" }, - { name = "sse-starlette" }, - { name = "starlette" }, - { name = "uvicorn" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ab/a5/b08dc846ebedae9f17ced878e6975826e90e448cd4592f532f6a88a925a7/mcp-1.2.0.tar.gz", hash = "sha256:2b06c7ece98d6ea9e6379caa38d74b432385c338fb530cb82e2c70ea7add94f5", size = 102973 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/af/84/fca78f19ac8ce6c53ba416247c71baa53a9e791e98d3c81edbc20a77d6d1/mcp-1.2.0-py3-none-any.whl", hash = "sha256:1d0e77d8c14955a5aea1f5aa1f444c8e531c09355c829b20e42f7a142bc0755f", size = 66468 }, -] - -[[package]] -name = "mcp-simple-chatbot" -version = "0.1.0" -source = { editable = "." } -dependencies = [ - { name = "mcp" }, - { name = "python-dotenv" }, - { name = "requests" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "pytest" }, - { name = "ruff" }, -] - -[package.metadata] -requires-dist = [ - { name = "mcp", specifier = ">=1.0.0" }, - { name = "python-dotenv", specifier = ">=1.0.0" }, - { name = "requests", specifier = ">=2.31.0" }, - { name = "uvicorn", specifier = ">=0.32.1" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.379" }, - { name = "pytest", specifier = ">=8.3.3" }, - { name = "ruff", specifier = ">=0.6.9" }, -] - -[[package]] -name = "nodeenv" -version = "1.9.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, -] - -[[package]] -name = "packaging" -version = "24.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, -] - -[[package]] -name = "pluggy" -version = "1.5.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, -] - -[[package]] -name = "pydantic" -version = "2.10.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "annotated-types" }, - { name = "pydantic-core" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/6a/c7/ca334c2ef6f2e046b1144fe4bb2a5da8a4c574e7f2ebf7e16b34a6a2fa92/pydantic-2.10.5.tar.gz", hash = "sha256:278b38dbbaec562011d659ee05f63346951b3a248a6f3642e1bc68894ea2b4ff", size = 761287 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/58/26/82663c79010b28eddf29dcdd0ea723439535fa917fce5905885c0e9ba562/pydantic-2.10.5-py3-none-any.whl", hash = "sha256:4dd4e322dbe55472cb7ca7e73f4b63574eecccf2835ffa2af9021ce113c83c53", size = 431426 }, -] - -[[package]] -name = "pydantic-core" -version = "2.27.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fc/01/f3e5ac5e7c25833db5eb555f7b7ab24cd6f8c322d3a3ad2d67a952dc0abc/pydantic_core-2.27.2.tar.gz", hash = "sha256:eb026e5a4c1fee05726072337ff51d1efb6f59090b7da90d30ea58625b1ffb39", size = 413443 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/bc/fed5f74b5d802cf9a03e83f60f18864e90e3aed7223adaca5ffb7a8d8d64/pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa", size = 1895938 }, - { url = "https://files.pythonhosted.org/packages/71/2a/185aff24ce844e39abb8dd680f4e959f0006944f4a8a0ea372d9f9ae2e53/pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c", size = 1815684 }, - { url = "https://files.pythonhosted.org/packages/c3/43/fafabd3d94d159d4f1ed62e383e264f146a17dd4d48453319fd782e7979e/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7969e133a6f183be60e9f6f56bfae753585680f3b7307a8e555a948d443cc05a", size = 1829169 }, - { url = "https://files.pythonhosted.org/packages/a2/d1/f2dfe1a2a637ce6800b799aa086d079998959f6f1215eb4497966efd2274/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3de9961f2a346257caf0aa508a4da705467f53778e9ef6fe744c038119737ef5", size = 1867227 }, - { url = "https://files.pythonhosted.org/packages/7d/39/e06fcbcc1c785daa3160ccf6c1c38fea31f5754b756e34b65f74e99780b5/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2bb4d3e5873c37bb3dd58714d4cd0b0e6238cebc4177ac8fe878f8b3aa8e74c", size = 2037695 }, - { url = "https://files.pythonhosted.org/packages/7a/67/61291ee98e07f0650eb756d44998214231f50751ba7e13f4f325d95249ab/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:280d219beebb0752699480fe8f1dc61ab6615c2046d76b7ab7ee38858de0a4e7", size = 2741662 }, - { url = "https://files.pythonhosted.org/packages/32/90/3b15e31b88ca39e9e626630b4c4a1f5a0dfd09076366f4219429e6786076/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47956ae78b6422cbd46f772f1746799cbb862de838fd8d1fbd34a82e05b0983a", size = 1993370 }, - { url = "https://files.pythonhosted.org/packages/ff/83/c06d333ee3a67e2e13e07794995c1535565132940715931c1c43bfc85b11/pydantic_core-2.27.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:14d4a5c49d2f009d62a2a7140d3064f686d17a5d1a268bc641954ba181880236", size = 1996813 }, - { url = "https://files.pythonhosted.org/packages/7c/f7/89be1c8deb6e22618a74f0ca0d933fdcb8baa254753b26b25ad3acff8f74/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:337b443af21d488716f8d0b6164de833e788aa6bd7e3a39c005febc1284f4962", size = 2005287 }, - { url = "https://files.pythonhosted.org/packages/b7/7d/8eb3e23206c00ef7feee17b83a4ffa0a623eb1a9d382e56e4aa46fd15ff2/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:03d0f86ea3184a12f41a2d23f7ccb79cdb5a18e06993f8a45baa8dfec746f0e9", size = 2128414 }, - { url = "https://files.pythonhosted.org/packages/4e/99/fe80f3ff8dd71a3ea15763878d464476e6cb0a2db95ff1c5c554133b6b83/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7041c36f5680c6e0f08d922aed302e98b3745d97fe1589db0a3eebf6624523af", size = 2155301 }, - { url = "https://files.pythonhosted.org/packages/2b/a3/e50460b9a5789ca1451b70d4f52546fa9e2b420ba3bfa6100105c0559238/pydantic_core-2.27.2-cp310-cp310-win32.whl", hash = "sha256:50a68f3e3819077be2c98110c1f9dcb3817e93f267ba80a2c05bb4f8799e2ff4", size = 1816685 }, - { url = "https://files.pythonhosted.org/packages/57/4c/a8838731cb0f2c2a39d3535376466de6049034d7b239c0202a64aaa05533/pydantic_core-2.27.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0fd26b16394ead34a424eecf8a31a1f5137094cabe84a1bcb10fa6ba39d3d31", size = 1982876 }, - { url = "https://files.pythonhosted.org/packages/c2/89/f3450af9d09d44eea1f2c369f49e8f181d742f28220f88cc4dfaae91ea6e/pydantic_core-2.27.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8e10c99ef58cfdf2a66fc15d66b16c4a04f62bca39db589ae8cba08bc55331bc", size = 1893421 }, - { url = "https://files.pythonhosted.org/packages/9e/e3/71fe85af2021f3f386da42d291412e5baf6ce7716bd7101ea49c810eda90/pydantic_core-2.27.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:26f32e0adf166a84d0cb63be85c562ca8a6fa8de28e5f0d92250c6b7e9e2aff7", size = 1814998 }, - { url = "https://files.pythonhosted.org/packages/a6/3c/724039e0d848fd69dbf5806894e26479577316c6f0f112bacaf67aa889ac/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c19d1ea0673cd13cc2f872f6c9ab42acc4e4f492a7ca9d3795ce2b112dd7e15", size = 1826167 }, - { url = "https://files.pythonhosted.org/packages/2b/5b/1b29e8c1fb5f3199a9a57c1452004ff39f494bbe9bdbe9a81e18172e40d3/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e68c4446fe0810e959cdff46ab0a41ce2f2c86d227d96dc3847af0ba7def306", size = 1865071 }, - { url = "https://files.pythonhosted.org/packages/89/6c/3985203863d76bb7d7266e36970d7e3b6385148c18a68cc8915fd8c84d57/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9640b0059ff4f14d1f37321b94061c6db164fbe49b334b31643e0528d100d99", size = 2036244 }, - { url = "https://files.pythonhosted.org/packages/0e/41/f15316858a246b5d723f7d7f599f79e37493b2e84bfc789e58d88c209f8a/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40d02e7d45c9f8af700f3452f329ead92da4c5f4317ca9b896de7ce7199ea459", size = 2737470 }, - { url = "https://files.pythonhosted.org/packages/a8/7c/b860618c25678bbd6d1d99dbdfdf0510ccb50790099b963ff78a124b754f/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c1fd185014191700554795c99b347d64f2bb637966c4cfc16998a0ca700d048", size = 1992291 }, - { url = "https://files.pythonhosted.org/packages/bf/73/42c3742a391eccbeab39f15213ecda3104ae8682ba3c0c28069fbcb8c10d/pydantic_core-2.27.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d81d2068e1c1228a565af076598f9e7451712700b673de8f502f0334f281387d", size = 1994613 }, - { url = "https://files.pythonhosted.org/packages/94/7a/941e89096d1175d56f59340f3a8ebaf20762fef222c298ea96d36a6328c5/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a4207639fb02ec2dbb76227d7c751a20b1a6b4bc52850568e52260cae64ca3b", size = 2002355 }, - { url = "https://files.pythonhosted.org/packages/6e/95/2359937a73d49e336a5a19848713555605d4d8d6940c3ec6c6c0ca4dcf25/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:3de3ce3c9ddc8bbd88f6e0e304dea0e66d843ec9de1b0042b0911c1663ffd474", size = 2126661 }, - { url = "https://files.pythonhosted.org/packages/2b/4c/ca02b7bdb6012a1adef21a50625b14f43ed4d11f1fc237f9d7490aa5078c/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:30c5f68ded0c36466acede341551106821043e9afaad516adfb6e8fa80a4e6a6", size = 2153261 }, - { url = "https://files.pythonhosted.org/packages/72/9d/a241db83f973049a1092a079272ffe2e3e82e98561ef6214ab53fe53b1c7/pydantic_core-2.27.2-cp311-cp311-win32.whl", hash = "sha256:c70c26d2c99f78b125a3459f8afe1aed4d9687c24fd677c6a4436bc042e50d6c", size = 1812361 }, - { url = "https://files.pythonhosted.org/packages/e8/ef/013f07248041b74abd48a385e2110aa3a9bbfef0fbd97d4e6d07d2f5b89a/pydantic_core-2.27.2-cp311-cp311-win_amd64.whl", hash = "sha256:08e125dbdc505fa69ca7d9c499639ab6407cfa909214d500897d02afb816e7cc", size = 1982484 }, - { url = "https://files.pythonhosted.org/packages/10/1c/16b3a3e3398fd29dca77cea0a1d998d6bde3902fa2706985191e2313cc76/pydantic_core-2.27.2-cp311-cp311-win_arm64.whl", hash = "sha256:26f0d68d4b235a2bae0c3fc585c585b4ecc51382db0e3ba402a22cbc440915e4", size = 1867102 }, - { url = "https://files.pythonhosted.org/packages/d6/74/51c8a5482ca447871c93e142d9d4a92ead74de6c8dc5e66733e22c9bba89/pydantic_core-2.27.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9e0c8cfefa0ef83b4da9588448b6d8d2a2bf1a53c3f1ae5fca39eb3061e2f0b0", size = 1893127 }, - { url = "https://files.pythonhosted.org/packages/d3/f3/c97e80721735868313c58b89d2de85fa80fe8dfeeed84dc51598b92a135e/pydantic_core-2.27.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83097677b8e3bd7eaa6775720ec8e0405f1575015a463285a92bfdfe254529ef", size = 1811340 }, - { url = "https://files.pythonhosted.org/packages/9e/91/840ec1375e686dbae1bd80a9e46c26a1e0083e1186abc610efa3d9a36180/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:172fce187655fece0c90d90a678424b013f8fbb0ca8b036ac266749c09438cb7", size = 1822900 }, - { url = "https://files.pythonhosted.org/packages/f6/31/4240bc96025035500c18adc149aa6ffdf1a0062a4b525c932065ceb4d868/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:519f29f5213271eeeeb3093f662ba2fd512b91c5f188f3bb7b27bc5973816934", size = 1869177 }, - { url = "https://files.pythonhosted.org/packages/fa/20/02fbaadb7808be578317015c462655c317a77a7c8f0ef274bc016a784c54/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05e3a55d124407fffba0dd6b0c0cd056d10e983ceb4e5dbd10dda135c31071d6", size = 2038046 }, - { url = "https://files.pythonhosted.org/packages/06/86/7f306b904e6c9eccf0668248b3f272090e49c275bc488a7b88b0823444a4/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c3ed807c7b91de05e63930188f19e921d1fe90de6b4f5cd43ee7fcc3525cb8c", size = 2685386 }, - { url = "https://files.pythonhosted.org/packages/8d/f0/49129b27c43396581a635d8710dae54a791b17dfc50c70164866bbf865e3/pydantic_core-2.27.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fb4aadc0b9a0c063206846d603b92030eb6f03069151a625667f982887153e2", size = 1997060 }, - { url = "https://files.pythonhosted.org/packages/0d/0f/943b4af7cd416c477fd40b187036c4f89b416a33d3cc0ab7b82708a667aa/pydantic_core-2.27.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:28ccb213807e037460326424ceb8b5245acb88f32f3d2777427476e1b32c48c4", size = 2004870 }, - { url = "https://files.pythonhosted.org/packages/35/40/aea70b5b1a63911c53a4c8117c0a828d6790483f858041f47bab0b779f44/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:de3cd1899e2c279b140adde9357c4495ed9d47131b4a4eaff9052f23398076b3", size = 1999822 }, - { url = "https://files.pythonhosted.org/packages/f2/b3/807b94fd337d58effc5498fd1a7a4d9d59af4133e83e32ae39a96fddec9d/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:220f892729375e2d736b97d0e51466252ad84c51857d4d15f5e9692f9ef12be4", size = 2130364 }, - { url = "https://files.pythonhosted.org/packages/fc/df/791c827cd4ee6efd59248dca9369fb35e80a9484462c33c6649a8d02b565/pydantic_core-2.27.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a0fcd29cd6b4e74fe8ddd2c90330fd8edf2e30cb52acda47f06dd615ae72da57", size = 2158303 }, - { url = "https://files.pythonhosted.org/packages/9b/67/4e197c300976af185b7cef4c02203e175fb127e414125916bf1128b639a9/pydantic_core-2.27.2-cp312-cp312-win32.whl", hash = "sha256:1e2cb691ed9834cd6a8be61228471d0a503731abfb42f82458ff27be7b2186fc", size = 1834064 }, - { url = "https://files.pythonhosted.org/packages/1f/ea/cd7209a889163b8dcca139fe32b9687dd05249161a3edda62860430457a5/pydantic_core-2.27.2-cp312-cp312-win_amd64.whl", hash = "sha256:cc3f1a99a4f4f9dd1de4fe0312c114e740b5ddead65bb4102884b384c15d8bc9", size = 1989046 }, - { url = "https://files.pythonhosted.org/packages/bc/49/c54baab2f4658c26ac633d798dab66b4c3a9bbf47cff5284e9c182f4137a/pydantic_core-2.27.2-cp312-cp312-win_arm64.whl", hash = "sha256:3911ac9284cd8a1792d3cb26a2da18f3ca26c6908cc434a18f730dc0db7bfa3b", size = 1885092 }, - { url = "https://files.pythonhosted.org/packages/41/b1/9bc383f48f8002f99104e3acff6cba1231b29ef76cfa45d1506a5cad1f84/pydantic_core-2.27.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7d14bd329640e63852364c306f4d23eb744e0f8193148d4044dd3dacdaacbd8b", size = 1892709 }, - { url = "https://files.pythonhosted.org/packages/10/6c/e62b8657b834f3eb2961b49ec8e301eb99946245e70bf42c8817350cbefc/pydantic_core-2.27.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:82f91663004eb8ed30ff478d77c4d1179b3563df6cdb15c0817cd1cdaf34d154", size = 1811273 }, - { url = "https://files.pythonhosted.org/packages/ba/15/52cfe49c8c986e081b863b102d6b859d9defc63446b642ccbbb3742bf371/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71b24c7d61131bb83df10cc7e687433609963a944ccf45190cfc21e0887b08c9", size = 1823027 }, - { url = "https://files.pythonhosted.org/packages/b1/1c/b6f402cfc18ec0024120602bdbcebc7bdd5b856528c013bd4d13865ca473/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fa8e459d4954f608fa26116118bb67f56b93b209c39b008277ace29937453dc9", size = 1868888 }, - { url = "https://files.pythonhosted.org/packages/bd/7b/8cb75b66ac37bc2975a3b7de99f3c6f355fcc4d89820b61dffa8f1e81677/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce8918cbebc8da707ba805b7fd0b382816858728ae7fe19a942080c24e5b7cd1", size = 2037738 }, - { url = "https://files.pythonhosted.org/packages/c8/f1/786d8fe78970a06f61df22cba58e365ce304bf9b9f46cc71c8c424e0c334/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eda3f5c2a021bbc5d976107bb302e0131351c2ba54343f8a496dc8783d3d3a6a", size = 2685138 }, - { url = "https://files.pythonhosted.org/packages/a6/74/d12b2cd841d8724dc8ffb13fc5cef86566a53ed358103150209ecd5d1999/pydantic_core-2.27.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd8086fa684c4775c27f03f062cbb9eaa6e17f064307e86b21b9e0abc9c0f02e", size = 1997025 }, - { url = "https://files.pythonhosted.org/packages/a0/6e/940bcd631bc4d9a06c9539b51f070b66e8f370ed0933f392db6ff350d873/pydantic_core-2.27.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8d9b3388db186ba0c099a6d20f0604a44eabdeef1777ddd94786cdae158729e4", size = 2004633 }, - { url = "https://files.pythonhosted.org/packages/50/cc/a46b34f1708d82498c227d5d80ce615b2dd502ddcfd8376fc14a36655af1/pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7a66efda2387de898c8f38c0cf7f14fca0b51a8ef0b24bfea5849f1b3c95af27", size = 1999404 }, - { url = "https://files.pythonhosted.org/packages/ca/2d/c365cfa930ed23bc58c41463bae347d1005537dc8db79e998af8ba28d35e/pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:18a101c168e4e092ab40dbc2503bdc0f62010e95d292b27827871dc85450d7ee", size = 2130130 }, - { url = "https://files.pythonhosted.org/packages/f4/d7/eb64d015c350b7cdb371145b54d96c919d4db516817f31cd1c650cae3b21/pydantic_core-2.27.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ba5dd002f88b78a4215ed2f8ddbdf85e8513382820ba15ad5ad8955ce0ca19a1", size = 2157946 }, - { url = "https://files.pythonhosted.org/packages/a4/99/bddde3ddde76c03b65dfd5a66ab436c4e58ffc42927d4ff1198ffbf96f5f/pydantic_core-2.27.2-cp313-cp313-win32.whl", hash = "sha256:1ebaf1d0481914d004a573394f4be3a7616334be70261007e47c2a6fe7e50130", size = 1834387 }, - { url = "https://files.pythonhosted.org/packages/71/47/82b5e846e01b26ac6f1893d3c5f9f3a2eb6ba79be26eef0b759b4fe72946/pydantic_core-2.27.2-cp313-cp313-win_amd64.whl", hash = "sha256:953101387ecf2f5652883208769a79e48db18c6df442568a0b5ccd8c2723abee", size = 1990453 }, - { url = "https://files.pythonhosted.org/packages/51/b2/b2b50d5ecf21acf870190ae5d093602d95f66c9c31f9d5de6062eb329ad1/pydantic_core-2.27.2-cp313-cp313-win_arm64.whl", hash = "sha256:ac4dbfd1691affb8f48c2c13241a2e3b60ff23247cbcf981759c768b6633cf8b", size = 1885186 }, - { url = "https://files.pythonhosted.org/packages/46/72/af70981a341500419e67d5cb45abe552a7c74b66326ac8877588488da1ac/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2bf14caea37e91198329b828eae1618c068dfb8ef17bb33287a7ad4b61ac314e", size = 1891159 }, - { url = "https://files.pythonhosted.org/packages/ad/3d/c5913cccdef93e0a6a95c2d057d2c2cba347815c845cda79ddd3c0f5e17d/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0cb791f5b45307caae8810c2023a184c74605ec3bcbb67d13846c28ff731ff8", size = 1768331 }, - { url = "https://files.pythonhosted.org/packages/f6/f0/a3ae8fbee269e4934f14e2e0e00928f9346c5943174f2811193113e58252/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:688d3fd9fcb71f41c4c015c023d12a79d1c4c0732ec9eb35d96e3388a120dcf3", size = 1822467 }, - { url = "https://files.pythonhosted.org/packages/d7/7a/7bbf241a04e9f9ea24cd5874354a83526d639b02674648af3f350554276c/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d591580c34f4d731592f0e9fe40f9cc1b430d297eecc70b962e93c5c668f15f", size = 1979797 }, - { url = "https://files.pythonhosted.org/packages/4f/5f/4784c6107731f89e0005a92ecb8a2efeafdb55eb992b8e9d0a2be5199335/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82f986faf4e644ffc189a7f1aafc86e46ef70372bb153e7001e8afccc6e54133", size = 1987839 }, - { url = "https://files.pythonhosted.org/packages/6d/a7/61246562b651dff00de86a5f01b6e4befb518df314c54dec187a78d81c84/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bec317a27290e2537f922639cafd54990551725fc844249e64c523301d0822fc", size = 1998861 }, - { url = "https://files.pythonhosted.org/packages/86/aa/837821ecf0c022bbb74ca132e117c358321e72e7f9702d1b6a03758545e2/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:0296abcb83a797db256b773f45773da397da75a08f5fcaef41f2044adec05f50", size = 2116582 }, - { url = "https://files.pythonhosted.org/packages/81/b0/5e74656e95623cbaa0a6278d16cf15e10a51f6002e3ec126541e95c29ea3/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0d75070718e369e452075a6017fbf187f788e17ed67a3abd47fa934d001863d9", size = 2151985 }, - { url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715 }, -] - -[[package]] -name = "pydantic-settings" -version = "2.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pydantic" }, - { name = "python-dotenv" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/73/7b/c58a586cd7d9ac66d2ee4ba60ca2d241fa837c02bca9bea80a9a8c3d22a9/pydantic_settings-2.7.1.tar.gz", hash = "sha256:10c9caad35e64bfb3c2fbf70a078c0e25cc92499782e5200747f942a065dec93", size = 79920 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/46/93416fdae86d40879714f72956ac14df9c7b76f7d41a4d68aa9f71a0028b/pydantic_settings-2.7.1-py3-none-any.whl", hash = "sha256:590be9e6e24d06db33a4262829edef682500ef008565a969c73d39d5f8bfb3fd", size = 29718 }, -] - -[[package]] -name = "pyright" -version = "1.1.392.post0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "nodeenv" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/66/df/3c6f6b08fba7ccf49b114dfc4bb33e25c299883fd763f93fad47ef8bc58d/pyright-1.1.392.post0.tar.gz", hash = "sha256:3b7f88de74a28dcfa90c7d90c782b6569a48c2be5f9d4add38472bdaac247ebd", size = 3789911 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/b1/a18de17f40e4f61ca58856b9ef9b0febf74ff88978c3f7776f910071f567/pyright-1.1.392.post0-py3-none-any.whl", hash = "sha256:252f84458a46fa2f0fd4e2f91fc74f50b9ca52c757062e93f6c250c0d8329eb2", size = 5595487 }, -] - -[[package]] -name = "pytest" -version = "8.3.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, -] - -[[package]] -name = "python-dotenv" -version = "1.0.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, -] - -[[package]] -name = "requests" -version = "2.32.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "charset-normalizer" }, - { name = "idna" }, - { name = "urllib3" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, -] - -[[package]] -name = "ruff" -version = "0.9.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/80/63/77ecca9d21177600f551d1c58ab0e5a0b260940ea7312195bd2a4798f8a8/ruff-0.9.2.tar.gz", hash = "sha256:b5eceb334d55fae5f316f783437392642ae18e16dcf4f1858d55d3c2a0f8f5d0", size = 3553799 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/af/b9/0e168e4e7fb3af851f739e8f07889b91d1a33a30fca8c29fa3149d6b03ec/ruff-0.9.2-py3-none-linux_armv6l.whl", hash = "sha256:80605a039ba1454d002b32139e4970becf84b5fee3a3c3bf1c2af6f61a784347", size = 11652408 }, - { url = "https://files.pythonhosted.org/packages/2c/22/08ede5db17cf701372a461d1cb8fdde037da1d4fa622b69ac21960e6237e/ruff-0.9.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b9aab82bb20afd5f596527045c01e6ae25a718ff1784cb92947bff1f83068b00", size = 11587553 }, - { url = "https://files.pythonhosted.org/packages/42/05/dedfc70f0bf010230229e33dec6e7b2235b2a1b8cbb2a991c710743e343f/ruff-0.9.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fbd337bac1cfa96be615f6efcd4bc4d077edbc127ef30e2b8ba2a27e18c054d4", size = 11020755 }, - { url = "https://files.pythonhosted.org/packages/df/9b/65d87ad9b2e3def67342830bd1af98803af731243da1255537ddb8f22209/ruff-0.9.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82b35259b0cbf8daa22a498018e300b9bb0174c2bbb7bcba593935158a78054d", size = 11826502 }, - { url = "https://files.pythonhosted.org/packages/93/02/f2239f56786479e1a89c3da9bc9391120057fc6f4a8266a5b091314e72ce/ruff-0.9.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b6a9701d1e371bf41dca22015c3f89769da7576884d2add7317ec1ec8cb9c3c", size = 11390562 }, - { url = "https://files.pythonhosted.org/packages/c9/37/d3a854dba9931f8cb1b2a19509bfe59e00875f48ade632e95aefcb7a0aee/ruff-0.9.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9cc53e68b3c5ae41e8faf83a3b89f4a5d7b2cb666dff4b366bb86ed2a85b481f", size = 12548968 }, - { url = "https://files.pythonhosted.org/packages/fa/c3/c7b812bb256c7a1d5553433e95980934ffa85396d332401f6b391d3c4569/ruff-0.9.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8efd9da7a1ee314b910da155ca7e8953094a7c10d0c0a39bfde3fcfd2a015684", size = 13187155 }, - { url = "https://files.pythonhosted.org/packages/bd/5a/3c7f9696a7875522b66aa9bba9e326e4e5894b4366bd1dc32aa6791cb1ff/ruff-0.9.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3292c5a22ea9a5f9a185e2d131dc7f98f8534a32fb6d2ee7b9944569239c648d", size = 12704674 }, - { url = "https://files.pythonhosted.org/packages/be/d6/d908762257a96ce5912187ae9ae86792e677ca4f3dc973b71e7508ff6282/ruff-0.9.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a605fdcf6e8b2d39f9436d343d1f0ff70c365a1e681546de0104bef81ce88df", size = 14529328 }, - { url = "https://files.pythonhosted.org/packages/2d/c2/049f1e6755d12d9cd8823242fa105968f34ee4c669d04cac8cea51a50407/ruff-0.9.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547f7f256aa366834829a08375c297fa63386cbe5f1459efaf174086b564247", size = 12385955 }, - { url = "https://files.pythonhosted.org/packages/91/5a/a9bdb50e39810bd9627074e42743b00e6dc4009d42ae9f9351bc3dbc28e7/ruff-0.9.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d18bba3d3353ed916e882521bc3e0af403949dbada344c20c16ea78f47af965e", size = 11810149 }, - { url = "https://files.pythonhosted.org/packages/e5/fd/57df1a0543182f79a1236e82a79c68ce210efb00e97c30657d5bdb12b478/ruff-0.9.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b338edc4610142355ccf6b87bd356729b62bf1bc152a2fad5b0c7dc04af77bfe", size = 11479141 }, - { url = "https://files.pythonhosted.org/packages/dc/16/bc3fd1d38974f6775fc152a0554f8c210ff80f2764b43777163c3c45d61b/ruff-0.9.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:492a5e44ad9b22a0ea98cf72e40305cbdaf27fac0d927f8bc9e1df316dcc96eb", size = 12014073 }, - { url = "https://files.pythonhosted.org/packages/47/6b/e4ca048a8f2047eb652e1e8c755f384d1b7944f69ed69066a37acd4118b0/ruff-0.9.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:af1e9e9fe7b1f767264d26b1075ac4ad831c7db976911fa362d09b2d0356426a", size = 12435758 }, - { url = "https://files.pythonhosted.org/packages/c2/40/4d3d6c979c67ba24cf183d29f706051a53c36d78358036a9cd21421582ab/ruff-0.9.2-py3-none-win32.whl", hash = "sha256:71cbe22e178c5da20e1514e1e01029c73dc09288a8028a5d3446e6bba87a5145", size = 9796916 }, - { url = "https://files.pythonhosted.org/packages/c3/ef/7f548752bdb6867e6939489c87fe4da489ab36191525fadc5cede2a6e8e2/ruff-0.9.2-py3-none-win_amd64.whl", hash = "sha256:c5e1d6abc798419cf46eed03f54f2e0c3adb1ad4b801119dedf23fcaf69b55b5", size = 10773080 }, - { url = "https://files.pythonhosted.org/packages/0e/4e/33df635528292bd2d18404e4daabcd74ca8a9853b2e1df85ed3d32d24362/ruff-0.9.2-py3-none-win_arm64.whl", hash = "sha256:a1b63fa24149918f8b37cef2ee6fff81f24f0d74b6f0bdc37bc3e1f2143e41c6", size = 10001738 }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, -] - -[[package]] -name = "sse-starlette" -version = "2.2.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "starlette" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/71/a4/80d2a11af59fe75b48230846989e93979c892d3a20016b42bb44edb9e398/sse_starlette-2.2.1.tar.gz", hash = "sha256:54470d5f19274aeed6b2d473430b08b4b379ea851d953b11d7f1c4a2c118b419", size = 17376 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/e0/5b8bd393f27f4a62461c5cf2479c75a2cc2ffa330976f9f00f5f6e4f50eb/sse_starlette-2.2.1-py3-none-any.whl", hash = "sha256:6410a3d3ba0c89e7675d4c273a301d64649c03a5ef1ca101f10b47f895fd0e99", size = 10120 }, -] - -[[package]] -name = "starlette" -version = "0.45.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/90/4f/e1c9f4ec3dae67a94c9285ed275355d5f7cf0f3a5c34538c8ae5412af550/starlette-0.45.2.tar.gz", hash = "sha256:bba1831d15ae5212b22feab2f218bab6ed3cd0fc2dc1d4442443bb1ee52260e0", size = 2574026 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/ab/fe4f57c83620b39dfc9e7687ebad59129ff05170b99422105019d9a65eec/starlette-0.45.2-py3-none-any.whl", hash = "sha256:4daec3356fb0cb1e723a5235e5beaf375d2259af27532958e2d79df549dad9da", size = 71505 }, -] - -[[package]] -name = "tomli" -version = "2.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 }, - { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 }, - { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 }, - { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 }, - { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 }, - { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 }, - { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 }, - { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 }, - { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 }, - { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 }, - { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762 }, - { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453 }, - { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486 }, - { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349 }, - { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159 }, - { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243 }, - { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645 }, - { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584 }, - { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875 }, - { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418 }, - { url = "https://files.pythonhosted.org/packages/04/90/2ee5f2e0362cb8a0b6499dc44f4d7d48f8fff06d28ba46e6f1eaa61a1388/tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7", size = 132708 }, - { url = "https://files.pythonhosted.org/packages/c0/ec/46b4108816de6b385141f082ba99e315501ccd0a2ea23db4a100dd3990ea/tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c", size = 123582 }, - { url = "https://files.pythonhosted.org/packages/a0/bd/b470466d0137b37b68d24556c38a0cc819e8febe392d5b199dcd7f578365/tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13", size = 232543 }, - { url = "https://files.pythonhosted.org/packages/d9/e5/82e80ff3b751373f7cead2815bcbe2d51c895b3c990686741a8e56ec42ab/tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281", size = 241691 }, - { url = "https://files.pythonhosted.org/packages/05/7e/2a110bc2713557d6a1bfb06af23dd01e7dde52b6ee7dadc589868f9abfac/tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272", size = 251170 }, - { url = "https://files.pythonhosted.org/packages/64/7b/22d713946efe00e0adbcdfd6d1aa119ae03fd0b60ebed51ebb3fa9f5a2e5/tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140", size = 236530 }, - { url = "https://files.pythonhosted.org/packages/38/31/3a76f67da4b0cf37b742ca76beaf819dca0ebef26d78fc794a576e08accf/tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2", size = 258666 }, - { url = "https://files.pythonhosted.org/packages/07/10/5af1293da642aded87e8a988753945d0cf7e00a9452d3911dd3bb354c9e2/tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744", size = 243954 }, - { url = "https://files.pythonhosted.org/packages/5b/b9/1ed31d167be802da0fc95020d04cd27b7d7065cc6fbefdd2f9186f60d7bd/tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec", size = 98724 }, - { url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383 }, - { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, -] - -[[package]] -name = "typing-extensions" -version = "4.12.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, -] - -[[package]] -name = "urllib3" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, -] - -[[package]] -name = "uvicorn" -version = "0.34.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "h11" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4b/4d/938bd85e5bf2edeec766267a5015ad969730bb91e31b44021dfe8b22df6c/uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9", size = 76568 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 }, -] diff --git a/examples/clients/simple-task-client/README.md b/examples/clients/simple-task-client/README.md new file mode 100644 index 0000000000..103be0f1fb --- /dev/null +++ b/examples/clients/simple-task-client/README.md @@ -0,0 +1,43 @@ +# Simple Task Client + +A minimal MCP client demonstrating polling for task results over streamable HTTP. + +## Running + +First, start the simple-task server in another terminal: + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +Then run the client: + +```bash +cd examples/clients/simple-task-client +uv run mcp-simple-task-client +``` + +Use `--url` to connect to a different server. + +## What it does + +1. Connects to the server via streamable HTTP +2. Calls the `long_running_task` tool as a task +3. Polls the task status until completion +4. Retrieves and prints the result + +## Expected output + +```text +Available tools: ['long_running_task'] + +Calling tool as a task... +Task created: + Status: working - Starting work... + Status: working - Processing step 1... + Status: working - Processing step 2... + Status: completed - + +Result: Task completed! +``` diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py new file mode 100644 index 0000000000..2fc2cda8d9 --- /dev/null +++ b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .main import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/main.py b/examples/clients/simple-task-client/mcp_simple_task_client/main.py new file mode 100644 index 0000000000..12691162ab --- /dev/null +++ b/examples/clients/simple-task-client/mcp_simple_task_client/main.py @@ -0,0 +1,55 @@ +"""Simple task client demonstrating MCP tasks polling over streamable HTTP.""" + +import asyncio + +import click +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import CallToolResult, TextContent + + +async def run(url: str) -> None: + async with streamablehttp_client(url) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() + + # List tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + # Call the tool as a task + print("\nCalling tool as a task...") + + result = await session.experimental.call_tool_as_task( + "long_running_task", + arguments={}, + ttl=60000, + ) + task_id = result.task.taskId + print(f"Task created: {task_id}") + + # Poll until done (respects server's pollInterval hint) + async for status in session.experimental.poll_task(task_id): + print(f" Status: {status.status} - {status.statusMessage or ''}") + + # Check final status + if status.status != "completed": + print(f"Task ended with status: {status.status}") + return + + # Get the result + task_result = await session.experimental.get_task_result(task_id, CallToolResult) + content = task_result.content[0] + if isinstance(content, TextContent): + print(f"\nResult: {content.text}") + + +@click.command() +@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") +def main(url: str) -> int: + asyncio.run(run(url)) + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/clients/simple-task-client/pyproject.toml b/examples/clients/simple-task-client/pyproject.toml new file mode 100644 index 0000000000..da10392e3c --- /dev/null +++ b/examples/clients/simple-task-client/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task-client" +version = "0.1.0" +description = "A simple MCP client demonstrating task polling" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks", "client"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["click>=8.0", "mcp"] + +[project.scripts] +mcp-simple-task-client = "mcp_simple_task_client.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task_client"] + +[tool.pyright] +include = ["mcp_simple_task_client"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/clients/simple-task-interactive-client/README.md b/examples/clients/simple-task-interactive-client/README.md new file mode 100644 index 0000000000..ac73d2bc12 --- /dev/null +++ b/examples/clients/simple-task-interactive-client/README.md @@ -0,0 +1,87 @@ +# Simple Interactive Task Client + +A minimal MCP client demonstrating responses to interactive tasks (elicitation and sampling). + +## Running + +First, start the interactive task server in another terminal: + +```bash +cd examples/servers/simple-task-interactive +uv run mcp-simple-task-interactive +``` + +Then run the client: + +```bash +cd examples/clients/simple-task-interactive-client +uv run mcp-simple-task-interactive-client +``` + +Use `--url` to connect to a different server. + +## What it does + +1. Connects to the server via streamable HTTP +2. Calls `confirm_delete` - server asks for confirmation, client responds via terminal +3. Calls `write_haiku` - server requests LLM completion, client returns a hardcoded haiku + +## Key concepts + +### Elicitation callback + +```python +async def elicitation_callback(context, params) -> ElicitResult: + # Handle user input request from server + return ElicitResult(action="http://github.com/accept", content={"confirm": True}) +``` + +### Sampling callback + +```python +async def sampling_callback(context, params) -> CreateMessageResult: + # Handle LLM completion request from server + return CreateMessageResult(model="...", role="assistant", content=...) +``` + +### Using call_tool_as_task + +```python +# Call a tool as a task (returns immediately with task reference) +result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) +task_id = result.task.taskId + +# Get result - this delivers elicitation/sampling requests and blocks until complete +final = await session.experimental.get_task_result(task_id, CallToolResult) +``` + +**Important**: The `get_task_result()` call is what triggers the delivery of elicitation +and sampling requests to your callbacks. It blocks until the task completes and returns +the final result. + +## Expected output + +```text +Available tools: ['confirm_delete', 'write_haiku'] + +--- Demo 1: Elicitation --- +Calling confirm_delete tool... +Task created: + +[Elicitation] Server asks: Are you sure you want to delete 'important.txt'? +Your response (y/n): y +[Elicitation] Responding with: confirm=True +Result: Deleted 'important.txt' + +--- Demo 2: Sampling --- +Calling write_haiku tool... +Task created: + +[Sampling] Server requests LLM completion for: Write a haiku about autumn leaves +[Sampling] Responding with haiku +Result: +Haiku: +Cherry blossoms fall +Softly on the quiet pond +Spring whispers goodbye +``` diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py new file mode 100644 index 0000000000..2fc2cda8d9 --- /dev/null +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .main import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py new file mode 100644 index 0000000000..a8a47dc57c --- /dev/null +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py @@ -0,0 +1,138 @@ +"""Simple interactive task client demonstrating elicitation and sampling responses. + +This example demonstrates the spec-compliant polling pattern: +1. Poll tasks/get watching for status changes +2. On input_required, call tasks/result to receive elicitation/sampling requests +3. Continue until terminal status, then retrieve final result +""" + +import asyncio +from typing import Any + +import click +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.context import RequestContext +from mcp.types import ( + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequestParams, + ElicitResult, + TextContent, +) + + +async def elicitation_callback( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, +) -> ElicitResult: + """Handle elicitation requests from the server.""" + print(f"\n[Elicitation] Server asks: {params.message}") + + # Simple terminal prompt + response = input("Your response (y/n): ").strip().lower() + confirmed = response in ("y", "yes", "true", "1") + + print(f"[Elicitation] Responding with: confirm={confirmed}") + return ElicitResult(action="http://github.com/accept", content={"confirm": confirmed}) + + +async def sampling_callback( + context: RequestContext[ClientSession, Any], + params: CreateMessageRequestParams, +) -> CreateMessageResult: + """Handle sampling requests from the server.""" + # Get the prompt from the first message + prompt = "unknown" + if params.messages: + content = params.messages[0].content + if isinstance(content, TextContent): + prompt = content.text + + print(f"\n[Sampling] Server requests LLM completion for: {prompt}") + + # Return a hardcoded haiku (in real use, call your LLM here) + haiku = """Cherry blossoms fall +Softly on the quiet pond +Spring whispers goodbye""" + + print("[Sampling] Responding with haiku") + return CreateMessageResult( + model="mock-haiku-model", + role="assistant", + content=TextContent(type="text", text=haiku), + ) + + +def get_text(result: CallToolResult) -> str: + """Extract text from a CallToolResult.""" + if result.content and isinstance(result.content[0], TextContent): + return result.content[0].text + return "(no text)" + + +async def run(url: str) -> None: + async with streamablehttp_client(url) as (read, write, _): + async with ClientSession( + read, + write, + elicitation_callback=elicitation_callback, + sampling_callback=sampling_callback, + ) as session: + await session.initialize() + + # List tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + # Demo 1: Elicitation (confirm_delete) + print("\n--- Demo 1: Elicitation ---") + print("Calling confirm_delete tool...") + + elicit_task = await session.experimental.call_tool_as_task("confirm_delete", {"filename": "important.txt"}) + elicit_task_id = elicit_task.task.taskId + print(f"Task created: {elicit_task_id}") + + # Poll until terminal, calling tasks/result on input_required + async for status in session.experimental.poll_task(elicit_task_id): + print(f"[Poll] Status: {status.status}") + if status.status == "input_required": + # Server needs input - tasks/result delivers the elicitation request + elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) + break + else: + # poll_task exited due to terminal status + elicit_result = await session.experimental.get_task_result(elicit_task_id, CallToolResult) + + print(f"Result: {get_text(elicit_result)}") + + # Demo 2: Sampling (write_haiku) + print("\n--- Demo 2: Sampling ---") + print("Calling write_haiku tool...") + + sampling_task = await session.experimental.call_tool_as_task("write_haiku", {"topic": "autumn leaves"}) + sampling_task_id = sampling_task.task.taskId + print(f"Task created: {sampling_task_id}") + + # Poll until terminal, calling tasks/result on input_required + async for status in session.experimental.poll_task(sampling_task_id): + print(f"[Poll] Status: {status.status}") + if status.status == "input_required": + sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) + break + else: + sampling_result = await session.experimental.get_task_result(sampling_task_id, CallToolResult) + + print(f"Result:\n{get_text(sampling_result)}") + + +@click.command() +@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") +def main(url: str) -> int: + asyncio.run(run(url)) + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/clients/simple-task-interactive-client/pyproject.toml b/examples/clients/simple-task-interactive-client/pyproject.toml new file mode 100644 index 0000000000..224bbc5917 --- /dev/null +++ b/examples/clients/simple-task-interactive-client/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task-interactive-client" +version = "0.1.0" +description = "A simple MCP client demonstrating interactive task responses" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks", "client", "elicitation", "sampling"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["click>=8.0", "mcp"] + +[project.scripts] +mcp-simple-task-interactive-client = "mcp_simple_task_interactive_client.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task_interactive_client"] + +[tool.pyright] +include = ["mcp_simple_task_interactive_client"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/clients/sse-polling-client/README.md b/examples/clients/sse-polling-client/README.md new file mode 100644 index 0000000000..78449aa832 --- /dev/null +++ b/examples/clients/sse-polling-client/README.md @@ -0,0 +1,30 @@ +# MCP SSE Polling Demo Client + +Demonstrates client-side auto-reconnect for the SSE polling pattern (SEP-1699). + +## Features + +- Connects to SSE polling demo server +- Automatically reconnects when server closes SSE stream +- Resumes from Last-Event-ID to avoid missing messages +- Respects server-provided retry interval + +## Usage + +```bash +# First start the server: +uv run mcp-sse-polling-demo --port 3000 + +# Then run this client: +uv run mcp-sse-polling-client --url http://localhost:3000/mcp + +# Custom options: +uv run mcp-sse-polling-client --url http://localhost:3000/mcp --items 20 --checkpoint-every 5 +``` + +## Options + +- `--url`: Server URL (default: ) +- `--items`: Number of items to process (default: 10) +- `--checkpoint-every`: Checkpoint interval (default: 3) +- `--log-level`: Logging level (default: DEBUG) diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/__init__.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/__init__.py new file mode 100644 index 0000000000..ee69b32c96 --- /dev/null +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/__init__.py @@ -0,0 +1 @@ +"""SSE Polling Demo Client - demonstrates auto-reconnect for long-running tasks.""" diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py new file mode 100644 index 0000000000..1defd8eaa4 --- /dev/null +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py @@ -0,0 +1,105 @@ +""" +SSE Polling Demo Client + +Demonstrates the client-side auto-reconnect for SSE polling pattern. + +This client connects to the SSE Polling Demo server and calls process_batch, +which triggers periodic server-side stream closes. The client automatically +reconnects using Last-Event-ID and resumes receiving messages. + +Run with: + # First start the server: + uv run mcp-sse-polling-demo --port 3000 + + # Then run this client: + uv run mcp-sse-polling-client --url http://localhost:3000/mcp +""" + +import asyncio +import logging + +import click +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + +logger = logging.getLogger(__name__) + + +async def run_demo(url: str, items: int, checkpoint_every: int) -> None: + """Run the SSE polling demo.""" + print(f"\n{'=' * 60}") + print("SSE Polling Demo Client") + print(f"{'=' * 60}") + print(f"Server URL: {url}") + print(f"Processing {items} items with checkpoints every {checkpoint_every}") + print(f"{'=' * 60}\n") + + async with streamablehttp_client(url) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the connection + print("Initializing connection...") + await session.initialize() + print("Connected!\n") + + # List available tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}\n") + + # Call the process_batch tool + print(f"Calling process_batch(items={items}, checkpoint_every={checkpoint_every})...\n") + print("-" * 40) + + result = await session.call_tool( + "process_batch", + { + "items": items, + "checkpoint_every": checkpoint_every, + }, + ) + + print("-" * 40) + if result.content: + content = result.content[0] + text = getattr(content, "text", str(content)) + print(f"\nResult: {text}") + else: + print("\nResult: No content") + print(f"{'=' * 60}\n") + + +@click.command() +@click.option( + "--url", + default="http://localhost:3000/mcp", + help="Server URL", +) +@click.option( + "--items", + default=10, + help="Number of items to process", +) +@click.option( + "--checkpoint-every", + default=3, + help="Checkpoint interval", +) +@click.option( + "--log-level", + default="INFO", + help="Logging level", +) +def main(url: str, items: int, checkpoint_every: int, log_level: str) -> None: + """Run the SSE Polling Demo client.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + # Suppress noisy HTTP client logging + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("httpcore").setLevel(logging.WARNING) + + asyncio.run(run_demo(url, items, checkpoint_every)) + + +if __name__ == "__main__": + main() diff --git a/examples/clients/sse-polling-client/pyproject.toml b/examples/clients/sse-polling-client/pyproject.toml new file mode 100644 index 0000000000..ae896708d4 --- /dev/null +++ b/examples/clients/sse-polling-client/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-sse-polling-client" +version = "0.1.0" +description = "Demo client for SSE polling with auto-reconnect" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "sse", "polling", "client"] +license = { text = "MIT" } +dependencies = ["click>=8.2.0", "mcp"] + +[project.scripts] +mcp-sse-polling-client = "mcp_sse_polling_client.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_sse_polling_client"] + +[tool.pyright] +include = ["mcp_sse_polling_client"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/fastmcp/direct_call_tool_result_return.py b/examples/fastmcp/direct_call_tool_result_return.py new file mode 100644 index 0000000000..a441769b2a --- /dev/null +++ b/examples/fastmcp/direct_call_tool_result_return.py @@ -0,0 +1,24 @@ +""" +FastMCP Echo Server with direct CallToolResult return +""" + +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server.fastmcp import FastMCP +from mcp.types import CallToolResult, TextContent + +mcp = FastMCP("Echo Server") + + +class EchoResponse(BaseModel): + text: str + + +@mcp.tool() +def echo(text: str) -> Annotated[CallToolResult, EchoResponse]: + """Echo the input text with structure and metadata""" + return CallToolResult( + content=[TextContent(type="text", text=text)], structuredContent={"text": text}, _meta={"some": "metadata"} + ) diff --git a/examples/servers/everything-server/README.md b/examples/servers/everything-server/README.md new file mode 100644 index 0000000000..3512665cb9 --- /dev/null +++ b/examples/servers/everything-server/README.md @@ -0,0 +1,42 @@ +# MCP Everything Server + +A comprehensive MCP server implementing all protocol features for conformance testing. + +## Overview + +The Everything Server is a reference implementation that demonstrates all features of the Model Context Protocol (MCP). It is designed to be used with the [MCP Conformance Test Framework](https://github.com/modelcontextprotocol/conformance) to validate MCP client and server implementations. + +## Installation + +From the python-sdk root directory: + +```bash +uv sync --frozen +``` + +## Usage + +### Running the Server + +Start the server with default settings (port 3001): + +```bash +uv run -m mcp_everything_server +``` + +Or with custom options: + +```bash +uv run -m mcp_everything_server --port 3001 --log-level DEBUG +``` + +The server will be available at: `http://localhost:3001/mcp` + +### Command-Line Options + +- `--port` - Port to listen on (default: 3001) +- `--log-level` - Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL (default: INFO) + +## Running Conformance Tests + +See the [MCP Conformance Test Framework](https://github.com/modelcontextprotocol/conformance) for instructions on running conformance tests against this server. diff --git a/examples/servers/everything-server/mcp_everything_server/__init__.py b/examples/servers/everything-server/mcp_everything_server/__init__.py new file mode 100644 index 0000000000..d539062d4f --- /dev/null +++ b/examples/servers/everything-server/mcp_everything_server/__init__.py @@ -0,0 +1,3 @@ +"""MCP Everything Server - Comprehensive conformance test server.""" + +__version__ = "0.1.0" diff --git a/examples/servers/everything-server/mcp_everything_server/__main__.py b/examples/servers/everything-server/mcp_everything_server/__main__.py new file mode 100644 index 0000000000..2eff688f02 --- /dev/null +++ b/examples/servers/everything-server/mcp_everything_server/__main__.py @@ -0,0 +1,6 @@ +"""CLI entry point for the MCP Everything Server.""" + +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py new file mode 100644 index 0000000000..1f1ee7ecc4 --- /dev/null +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +""" +MCP Everything Server - Conformance Test Server + +Server implementing all MCP features for conformance testing based on Conformance Server Specification. +""" + +import asyncio +import base64 +import json +import logging + +import click +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.fastmcp.prompts.base import UserMessage +from mcp.server.session import ServerSession +from mcp.server.streamable_http import EventCallback, EventMessage, EventStore +from mcp.types import ( + AudioContent, + Completion, + CompletionArgument, + CompletionContext, + EmbeddedResource, + ImageContent, + JSONRPCMessage, + PromptReference, + ResourceTemplateReference, + SamplingMessage, + TextContent, + TextResourceContents, +) +from pydantic import AnyUrl, BaseModel, Field + +logger = logging.getLogger(__name__) + +# Type aliases for event store +StreamId = str +EventId = str + + +class InMemoryEventStore(EventStore): + """Simple in-memory event store for SSE resumability testing.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, EventId, JSONRPCMessage | None]] = [] + self._event_id_counter = 0 + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + """Store an event and return its ID.""" + self._event_id_counter += 1 + event_id = str(self._event_id_counter) + self._events.append((stream_id, event_id, message)) + return event_id + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + """Replay events after the specified ID.""" + target_stream_id = None + for stream_id, event_id, _ in self._events: + if event_id == last_event_id: + target_stream_id = stream_id + break + if target_stream_id is None: + return None + last_event_id_int = int(last_event_id) + for stream_id, event_id, message in self._events: + if stream_id == target_stream_id and int(event_id) > last_event_id_int: + # Skip priming events (None message) + if message is not None: + await send_callback(EventMessage(message, event_id)) + return target_stream_id + + +# Test data +TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" +TEST_AUDIO_BASE64 = "UklGRiYAAABXQVZFZm10IBAAAAABAAEAQB8AAAB9AAACABAAZGF0YQIAAAA=" + +# Server state +resource_subscriptions: set[str] = set() +watched_resource_content = "Watched resource content" + +# Create event store for SSE resumability (SEP-1699) +event_store = InMemoryEventStore() + +mcp = FastMCP( + name="mcp-conformance-test-server", + event_store=event_store, + retry_interval=100, # 100ms retry interval for SSE polling +) + + +# Tools +@mcp.tool() +def test_simple_text() -> str: + """Tests simple text content response""" + return "This is a simple text response for testing." + + +@mcp.tool() +def test_image_content() -> list[ImageContent]: + """Tests image content response""" + return [ImageContent(type="image", data=TEST_IMAGE_BASE64, mimeType="image/png")] + + +@mcp.tool() +def test_audio_content() -> list[AudioContent]: + """Tests audio content response""" + return [AudioContent(type="audio", data=TEST_AUDIO_BASE64, mimeType="audio/wav")] + + +@mcp.tool() +def test_embedded_resource() -> list[EmbeddedResource]: + """Tests embedded resource content response""" + return [ + EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=AnyUrl("test://embedded-resource"), + mimeType="text/plain", + text="This is an embedded resource content.", + ), + ) + ] + + +@mcp.tool() +def test_multiple_content_types() -> list[TextContent | ImageContent | EmbeddedResource]: + """Tests response with multiple content types (text, image, resource)""" + return [ + TextContent(type="text", text="Multiple content types test:"), + ImageContent(type="image", data=TEST_IMAGE_BASE64, mimeType="image/png"), + EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=AnyUrl("test://mixed-content-resource"), + mimeType="application/json", + text='{"test": "data", "value": 123}', + ), + ), + ] + + +@mcp.tool() +async def test_tool_with_logging(ctx: Context[ServerSession, None]) -> str: + """Tests tool that emits log messages during execution""" + await ctx.info("Tool execution started") + await asyncio.sleep(0.05) + + await ctx.info("Tool processing data") + await asyncio.sleep(0.05) + + await ctx.info("Tool execution completed") + return "Tool with logging executed successfully" + + +@mcp.tool() +async def test_tool_with_progress(ctx: Context[ServerSession, None]) -> str: + """Tests tool that reports progress notifications""" + await ctx.report_progress(progress=0, total=100, message="Completed step 0 of 100") + await asyncio.sleep(0.05) + + await ctx.report_progress(progress=50, total=100, message="Completed step 50 of 100") + await asyncio.sleep(0.05) + + await ctx.report_progress(progress=100, total=100, message="Completed step 100 of 100") + + # Return progress token as string + progress_token = ctx.request_context.meta.progressToken if ctx.request_context and ctx.request_context.meta else 0 + return str(progress_token) + + +@mcp.tool() +async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str: + """Tests server-initiated sampling (LLM completion request)""" + try: + # Request sampling from client + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], + max_tokens=100, + ) + + # Since we're not passing tools param, result.content is single content + if result.content.type == "text": + model_response = result.content.text + else: + model_response = "No response" + + return f"LLM response: {model_response}" + except Exception as e: + return f"Sampling not supported or error: {str(e)}" + + +class UserResponse(BaseModel): + response: str = Field(description="User's response") + + +@mcp.tool() +async def test_elicitation(message: str, ctx: Context[ServerSession, None]) -> str: + """Tests server-initiated elicitation (user input request)""" + try: + # Request user input from client + result = await ctx.elicit(message=message, schema=UserResponse) + + # Type-safe discriminated union narrowing using action field + if result.action == "accept": + content = result.data.model_dump_json() + else: # decline or cancel + content = "{}" + + return f"User response: action={result.action}, content={content}" + except Exception as e: + return f"Elicitation not supported or error: {str(e)}" + + +class SEP1034DefaultsSchema(BaseModel): + """Schema for testing SEP-1034 elicitation with default values for all primitive types""" + + name: str = Field(default="John Doe", description="User name") + age: int = Field(default=30, description="User age") + score: float = Field(default=95.5, description="User score") + status: str = Field( + default="active", + description="User status", + json_schema_extra={"enum": ["active", "inactive", "pending"]}, + ) + verified: bool = Field(default=True, description="Verification status") + + +@mcp.tool() +async def test_elicitation_sep1034_defaults(ctx: Context[ServerSession, None]) -> str: + """Tests elicitation with default values for all primitive types (SEP-1034)""" + try: + # Request user input with defaults for all primitive types + result = await ctx.elicit(message="Please provide user information", schema=SEP1034DefaultsSchema) + + # Type-safe discriminated union narrowing using action field + if result.action == "accept": + content = result.data.model_dump_json() + else: # decline or cancel + content = "{}" + + return f"Elicitation result: action={result.action}, content={content}" + except Exception as e: + return f"Elicitation not supported or error: {str(e)}" + + +class EnumSchemasTestSchema(BaseModel): + """Schema for testing enum schema variations (SEP-1330)""" + + untitledSingle: str = Field( + description="Simple enum without titles", json_schema_extra={"enum": ["active", "inactive", "pending"]} + ) + titledSingle: str = Field( + description="Enum with titled options (oneOf)", + json_schema_extra={ + "oneOf": [ + {"const": "low", "title": "Low Priority"}, + {"const": "medium", "title": "Medium Priority"}, + {"const": "high", "title": "High Priority"}, + ] + }, + ) + untitledMulti: list[str] = Field( + description="Multi-select without titles", + json_schema_extra={"items": {"type": "string", "enum": ["read", "write", "execute"]}}, + ) + titledMulti: list[str] = Field( + description="Multi-select with titled options", + json_schema_extra={ + "items": { + "anyOf": [ + {"const": "feature", "title": "New Feature"}, + {"const": "bug", "title": "Bug Fix"}, + {"const": "docs", "title": "Documentation"}, + ] + } + }, + ) + legacyEnum: str = Field( + description="Legacy enum with enumNames", + json_schema_extra={ + "enum": ["small", "medium", "large"], + "enumNames": ["Small Size", "Medium Size", "Large Size"], + }, + ) + + +@mcp.tool() +async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> str: + """Tests elicitation with enum schema variations per SEP-1330""" + try: + result = await ctx.elicit( + message="Please select values using different enum schema types", schema=EnumSchemasTestSchema + ) + + if result.action == "accept": + content = result.data.model_dump_json() + else: + content = "{}" + + return f"Elicitation completed: action={result.action}, content={content}" + except Exception as e: + return f"Elicitation not supported or error: {str(e)}" + + +@mcp.tool() +def test_error_handling() -> str: + """Tests error response handling""" + raise RuntimeError("This tool intentionally returns an error for testing") + + +@mcp.tool() +async def test_reconnection(ctx: Context[ServerSession, None]) -> str: + """Tests SSE polling by closing stream mid-call (SEP-1699)""" + await ctx.info("Before disconnect") + + await ctx.close_sse_stream() + + await asyncio.sleep(0.2) # Wait for client to reconnect + + await ctx.info("After reconnect") + return "Reconnection test completed" + + +# Resources +@mcp.resource("test://static-text") +def static_text_resource() -> str: + """A static text resource for testing""" + return "This is the content of the static text resource." + + +@mcp.resource("test://static-binary") +def static_binary_resource() -> bytes: + """A static binary resource (image) for testing""" + return base64.b64decode(TEST_IMAGE_BASE64) + + +@mcp.resource("test://template/{id}/data") +def template_resource(id: str) -> str: + """A resource template with parameter substitution""" + return json.dumps({"id": id, "templateTest": True, "data": f"Data for ID: {id}"}) + + +@mcp.resource("test://watched-resource") +def watched_resource() -> str: + """A resource that can be subscribed to for updates""" + return watched_resource_content + + +# Prompts +@mcp.prompt() +def test_simple_prompt() -> list[UserMessage]: + """A simple prompt without arguments""" + return [UserMessage(role="user", content=TextContent(type="text", text="This is a simple prompt for testing."))] + + +@mcp.prompt() +def test_prompt_with_arguments(arg1: str, arg2: str) -> list[UserMessage]: + """A prompt with required arguments""" + return [ + UserMessage( + role="user", content=TextContent(type="text", text=f"Prompt with arguments: arg1='{arg1}', arg2='{arg2}'") + ) + ] + + +@mcp.prompt() +def test_prompt_with_embedded_resource(resourceUri: str) -> list[UserMessage]: + """A prompt that includes an embedded resource""" + return [ + UserMessage( + role="user", + content=EmbeddedResource( + type="resource", + resource=TextResourceContents( + uri=AnyUrl(resourceUri), + mimeType="text/plain", + text="Embedded resource content for testing.", + ), + ), + ), + UserMessage(role="user", content=TextContent(type="text", text="Please process the embedded resource above.")), + ] + + +@mcp.prompt() +def test_prompt_with_image() -> list[UserMessage]: + """A prompt that includes image content""" + return [ + UserMessage(role="user", content=ImageContent(type="image", data=TEST_IMAGE_BASE64, mimeType="image/png")), + UserMessage(role="user", content=TextContent(type="text", text="Please analyze the image above.")), + ] + + +# Custom request handlers +# TODO(felix): Add public APIs to FastMCP for subscribe_resource, unsubscribe_resource, +# and set_logging_level to avoid accessing protected _mcp_server attribute. +@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage] +async def handle_set_logging_level(level: str) -> None: + """Handle logging level changes""" + logger.info(f"Log level set to: {level}") + # In a real implementation, you would adjust the logging level here + # For conformance testing, we just acknowledge the request + + +async def handle_subscribe(uri: AnyUrl) -> None: + """Handle resource subscription""" + resource_subscriptions.add(str(uri)) + logger.info(f"Subscribed to resource: {uri}") + + +async def handle_unsubscribe(uri: AnyUrl) -> None: + """Handle resource unsubscription""" + resource_subscriptions.discard(str(uri)) + logger.info(f"Unsubscribed from resource: {uri}") + + +mcp._mcp_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] +mcp._mcp_server.unsubscribe_resource()(handle_unsubscribe) # pyright: ignore[reportPrivateUsage] + + +@mcp.completion() +async def _handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, +) -> Completion: + """Handle completion requests""" + # Basic completion support - returns empty array for conformance + # Real implementations would provide contextual suggestions + return Completion(values=[], total=0, hasMore=False) + + +# CLI +@click.command() +@click.option("--port", default=3001, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +def main(port: int, log_level: str) -> int: + """Run the MCP Everything Server.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + logger.info(f"Starting MCP Everything Server on port {port}") + logger.info(f"Endpoint will be: http://localhost:{port}/mcp") + + mcp.settings.port = port + mcp.run(transport="streamable-http") + + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/servers/everything-server/pyproject.toml b/examples/servers/everything-server/pyproject.toml new file mode 100644 index 0000000000..ff67bf5577 --- /dev/null +++ b/examples/servers/everything-server/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-everything-server" +version = "0.1.0" +description = "Comprehensive MCP server implementing all protocol features for conformance testing" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "automation", "conformance", "testing"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.2.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-everything-server = "mcp_everything_server.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_everything_server"] + +[tool.pyright] +include = ["mcp_everything_server"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index c0a456cd38..5d88505708 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -45,11 +45,6 @@ class ResourceServerSettings(BaseSettings): # RFC 8707 resource validation oauth_strict: bool = False - # TODO(Marcelo): Is this even needed? I didn't have time to check. - def __init__(self, **data: Any): - """Initialize settings with values from environment variables.""" - super().__init__(**data) - def create_resource_server(settings: ResourceServerSettings) -> FastMCP: """ diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 0f1092d7d8..e3a25d3e8c 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -73,6 +73,8 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def register_client(self, client_info: OAuthClientInformationFull): """Register a new OAuth client.""" + if not client_info.client_id: + raise ValueError("No client_id provided") self.clients[client_info.client_id] = client_info async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: @@ -209,6 +211,8 @@ async def exchange_authorization_code( """Exchange authorization code for tokens.""" if authorization_code.code not in self.auth_codes: raise ValueError("Invalid authorization code") + if not client.client_id: + raise ValueError("No client_id provided") # Generate MCP access token mcp_token = f"mcp_{secrets.token_hex(32)}" diff --git a/examples/servers/simple-auth/pyproject.toml b/examples/servers/simple-auth/pyproject.toml index 7a1aeda177..eb2b18561b 100644 --- a/examples/servers/simple-auth/pyproject.toml +++ b/examples/servers/simple-auth/pyproject.toml @@ -29,5 +29,5 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["mcp_simple_auth"] -[tool.uv] -dev-dependencies = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"] +[dependency-groups] +dev = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"] diff --git a/examples/servers/simple-pagination/pyproject.toml b/examples/servers/simple-pagination/pyproject.toml index 0c60cf73c3..14de502574 100644 --- a/examples/servers/simple-pagination/pyproject.toml +++ b/examples/servers/simple-pagination/pyproject.toml @@ -43,5 +43,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/__init__.py b/examples/servers/simple-prompt/mcp_simple_prompt/__init__.py index 8b13789179..e69de29bb2 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/__init__.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/__init__.py @@ -1 +0,0 @@ - diff --git a/examples/servers/simple-prompt/pyproject.toml b/examples/servers/simple-prompt/pyproject.toml index f8cf1a1bef..28fe265746 100644 --- a/examples/servers/simple-prompt/pyproject.toml +++ b/examples/servers/simple-prompt/pyproject.toml @@ -43,5 +43,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-resource/mcp_simple_resource/__init__.py b/examples/servers/simple-resource/mcp_simple_resource/__init__.py index 8b13789179..e69de29bb2 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/__init__.py +++ b/examples/servers/simple-resource/mcp_simple_resource/__init__.py @@ -1 +0,0 @@ - diff --git a/examples/servers/simple-resource/pyproject.toml b/examples/servers/simple-resource/pyproject.toml index c63747f5ec..14c2bd38cc 100644 --- a/examples/servers/simple-resource/pyproject.toml +++ b/examples/servers/simple-resource/pyproject.toml @@ -43,5 +43,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-streamablehttp-stateless/pyproject.toml b/examples/servers/simple-streamablehttp-stateless/pyproject.toml index 41c08b0564..0e695695cb 100644 --- a/examples/servers/simple-streamablehttp-stateless/pyproject.toml +++ b/examples/servers/simple-streamablehttp-stateless/pyproject.toml @@ -32,5 +32,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py index ee52cdbe77..0c3081ed64 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -24,7 +24,7 @@ class EventEntry: event_id: EventId stream_id: StreamId - message: JSONRPCMessage + message: JSONRPCMessage | None class InMemoryEventStore(EventStore): @@ -48,7 +48,7 @@ def __init__(self, max_events_per_stream: int = 100): # event_id -> EventEntry for quick lookup self.event_index: dict[EventId, EventEntry] = {} - async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: """Stores an event with a generated event ID.""" event_id = str(uuid4()) event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) @@ -88,7 +88,9 @@ async def replay_events_after( found_last = False for event in stream_events: if found_last: - await send_callback(EventMessage(event.message, event.event_id)) + # Skip priming events (None message) + if event.message is not None: + await send_callback(EventMessage(event.message, event.event_id)) elif event.event_id == last_event_id: found_last = True diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml index dfc5295fb7..f0404fb7dd 100644 --- a/examples/servers/simple-streamablehttp/pyproject.toml +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -32,5 +32,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task-interactive/README.md b/examples/servers/simple-task-interactive/README.md new file mode 100644 index 0000000000..b8f384cb48 --- /dev/null +++ b/examples/servers/simple-task-interactive/README.md @@ -0,0 +1,74 @@ +# Simple Interactive Task Server + +A minimal MCP server demonstrating interactive tasks with elicitation and sampling. + +## Running + +```bash +cd examples/servers/simple-task-interactive +uv run mcp-simple-task-interactive +``` + +The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. + +## What it does + +This server exposes two tools: + +### `confirm_delete` (demonstrates elicitation) + +Asks the user for confirmation before "deleting" a file. + +- Uses `task.elicit()` to request user input +- Shows the elicitation flow: task -> input_required -> response -> complete + +### `write_haiku` (demonstrates sampling) + +Asks the LLM to write a haiku about a topic. + +- Uses `task.create_message()` to request LLM completion +- Shows the sampling flow: task -> input_required -> response -> complete + +## Usage with the client + +In one terminal, start the server: + +```bash +cd examples/servers/simple-task-interactive +uv run mcp-simple-task-interactive +``` + +In another terminal, run the interactive client: + +```bash +cd examples/clients/simple-task-interactive-client +uv run mcp-simple-task-interactive-client +``` + +## Expected server output + +When a client connects and calls the tools, you'll see: + +```text +Starting server on http://localhost:8000/mcp + +[Server] confirm_delete called for 'important.txt' +[Server] Task created: +[Server] Sending elicitation request to client... +[Server] Received elicitation response: action=accept, content={'confirm': True} +[Server] Completing task with result: Deleted 'important.txt' + +[Server] write_haiku called for topic 'autumn leaves' +[Server] Task created: +[Server] Sending sampling request to client... +[Server] Received sampling response: Cherry blossoms fall +Softly on the quiet pon... +[Server] Completing task with haiku +``` + +## Key concepts + +1. **ServerTaskContext**: Provides `elicit()` and `create_message()` for user interaction +2. **run_task()**: Spawns background work, auto-completes/fails, returns immediately +3. **TaskResultHandler**: Delivers queued messages and routes responses +4. **Response routing**: Responses are routed back to waiting resolvers diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py new file mode 100644 index 0000000000..e7ef16530b --- /dev/null +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py new file mode 100644 index 0000000000..4d35ca8094 --- /dev/null +++ b/examples/servers/simple-task-interactive/mcp_simple_task_interactive/server.py @@ -0,0 +1,147 @@ +"""Simple interactive task server demonstrating elicitation and sampling. + +This example shows the simplified task API where: +- server.experimental.enable_tasks() sets up all infrastructure +- ctx.experimental.run_task() handles task lifecycle automatically +- ServerTaskContext.elicit() and ServerTaskContext.create_message() queue requests properly +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import click +import mcp.types as types +import uvicorn +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount + +server = Server("simple-task-interactive") + +# Enable task support - this auto-registers all handlers +server.experimental.enable_tasks() + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="confirm_delete", + description="Asks for confirmation before deleting (demonstrates elicitation)", + inputSchema={ + "type": "object", + "properties": {"filename": {"type": "string"}}, + }, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + types.Tool( + name="write_haiku", + description="Asks LLM to write a haiku (demonstrates sampling)", + inputSchema={"type": "object", "properties": {"topic": {"type": "string"}}}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + ] + + +async def handle_confirm_delete(arguments: dict[str, Any]) -> types.CreateTaskResult: + """Handle the confirm_delete tool - demonstrates elicitation.""" + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + + filename = arguments.get("filename", "unknown.txt") + print(f"\n[Server] confirm_delete called for '{filename}'") + + async def work(task: ServerTaskContext) -> types.CallToolResult: + print(f"[Server] Task {task.task_id} starting elicitation...") + + result = await task.elicit( + message=f"Are you sure you want to delete '{filename}'?", + requestedSchema={ + "type": "object", + "properties": {"confirm": {"type": "boolean"}}, + "required": ["confirm"], + }, + ) + + print(f"[Server] Received elicitation response: action={result.action}, content={result.content}") + + if result.action == "accept" and result.content: + confirmed = result.content.get("confirm", False) + text = f"Deleted '{filename}'" if confirmed else "Deletion cancelled" + else: + text = "Deletion cancelled" + + print(f"[Server] Completing task with result: {text}") + return types.CallToolResult(content=[types.TextContent(type="text", text=text)]) + + return await ctx.experimental.run_task(work) + + +async def handle_write_haiku(arguments: dict[str, Any]) -> types.CreateTaskResult: + """Handle the write_haiku tool - demonstrates sampling.""" + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + + topic = arguments.get("topic", "nature") + print(f"\n[Server] write_haiku called for topic '{topic}'") + + async def work(task: ServerTaskContext) -> types.CallToolResult: + print(f"[Server] Task {task.task_id} starting sampling...") + + result = await task.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text=f"Write a haiku about {topic}"), + ) + ], + max_tokens=50, + ) + + haiku = "No response" + if isinstance(result.content, types.TextContent): + haiku = result.content.text + + print(f"[Server] Received sampling response: {haiku[:50]}...") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Haiku:\n{haiku}")]) + + return await ctx.experimental.run_task(work) + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + """Dispatch tool calls to their handlers.""" + if name == "confirm_delete": + return await handle_confirm_delete(arguments) + elif name == "write_haiku": + return await handle_write_haiku(arguments) + else: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], + isError=True, + ) + + +def create_app(session_manager: StreamableHTTPSessionManager) -> Starlette: + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + lifespan=app_lifespan, + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +def main(port: int) -> int: + session_manager = StreamableHTTPSessionManager(app=server) + starlette_app = create_app(session_manager) + print(f"Starting server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 diff --git a/examples/servers/simple-task-interactive/pyproject.toml b/examples/servers/simple-task-interactive/pyproject.toml new file mode 100644 index 0000000000..492345ff52 --- /dev/null +++ b/examples/servers/simple-task-interactive/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task-interactive" +version = "0.1.0" +description = "A simple MCP server demonstrating interactive tasks (elicitation & sampling)" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks", "elicitation", "sampling"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-task-interactive = "mcp_simple_task_interactive.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task_interactive"] + +[tool.pyright] +include = ["mcp_simple_task_interactive"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task/README.md b/examples/servers/simple-task/README.md new file mode 100644 index 0000000000..6914e0414f --- /dev/null +++ b/examples/servers/simple-task/README.md @@ -0,0 +1,37 @@ +# Simple Task Server + +A minimal MCP server demonstrating the experimental tasks feature over streamable HTTP. + +## Running + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. + +## What it does + +This server exposes a single tool `long_running_task` that: + +1. Must be called as a task (with `task` metadata in the request) +2. Takes ~3 seconds to complete +3. Sends status updates during execution +4. Returns a result when complete + +## Usage with the client + +In one terminal, start the server: + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +In another terminal, run the client: + +```bash +cd examples/clients/simple-task-client +uv run mcp-simple-task-client +``` diff --git a/examples/servers/simple-task/mcp_simple_task/__init__.py b/examples/servers/simple-task/mcp_simple_task/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/servers/simple-task/mcp_simple_task/__main__.py b/examples/servers/simple-task/mcp_simple_task/__main__.py new file mode 100644 index 0000000000..e7ef16530b --- /dev/null +++ b/examples/servers/simple-task/mcp_simple_task/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py new file mode 100644 index 0000000000..d0681b8423 --- /dev/null +++ b/examples/servers/simple-task/mcp_simple_task/server.py @@ -0,0 +1,84 @@ +"""Simple task server demonstrating MCP tasks over streamable HTTP.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import click +import mcp.types as types +import uvicorn +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount + +server = Server("simple-task-server") + +# One-line setup: auto-registers get_task, get_task_result, list_tasks, cancel_task +server.experimental.enable_tasks() + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="long_running_task", + description="A task that takes a few seconds to complete with status updates", + inputSchema={"type": "object", "properties": {}}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ) + ] + + +async def handle_long_running_task(arguments: dict[str, Any]) -> types.CreateTaskResult: + """Handle the long_running_task tool - demonstrates status updates.""" + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Starting work...") + await anyio.sleep(1) + + await task.update_status("Processing step 1...") + await anyio.sleep(1) + + await task.update_status("Processing step 2...") + await anyio.sleep(1) + + return types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) + + return await ctx.experimental.run_task(work) + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + """Dispatch tool calls to their handlers.""" + if name == "long_running_task": + return await handle_long_running_task(arguments) + else: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], + isError=True, + ) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +def main(port: int) -> int: + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + starlette_app = Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + lifespan=app_lifespan, + ) + + print(f"Starting server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 diff --git a/examples/servers/simple-task/pyproject.toml b/examples/servers/simple-task/pyproject.toml new file mode 100644 index 0000000000..a8fba8bdc1 --- /dev/null +++ b/examples/servers/simple-task/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task" +version = "0.1.0" +description = "A simple MCP server demonstrating tasks" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-task = "mcp_simple_task.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task"] + +[tool.pyright] +include = ["mcp_simple_task"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-tool/mcp_simple_tool/__init__.py b/examples/servers/simple-tool/mcp_simple_tool/__init__.py index 8b13789179..e69de29bb2 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/__init__.py +++ b/examples/servers/simple-tool/mcp_simple_tool/__init__.py @@ -1 +0,0 @@ - diff --git a/examples/servers/simple-tool/pyproject.toml b/examples/servers/simple-tool/pyproject.toml index 46d118cca4..c3944f3146 100644 --- a/examples/servers/simple-tool/pyproject.toml +++ b/examples/servers/simple-tool/pyproject.toml @@ -43,5 +43,5 @@ ignore = [] line-length = 120 target-version = "py310" -[tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/sse-polling-demo/README.md b/examples/servers/sse-polling-demo/README.md new file mode 100644 index 0000000000..e9d4446e1f --- /dev/null +++ b/examples/servers/sse-polling-demo/README.md @@ -0,0 +1,36 @@ +# MCP SSE Polling Demo Server + +Demonstrates the SSE polling pattern with server-initiated stream close for long-running tasks (SEP-1699). + +## Features + +- Priming events (automatic with EventStore) +- Server-initiated stream close via `close_sse_stream()` callback +- Client auto-reconnect with Last-Event-ID +- Progress notifications during long-running tasks +- Configurable retry interval + +## Usage + +```bash +# Start server on default port +uv run mcp-sse-polling-demo --port 3000 + +# Custom retry interval (milliseconds) +uv run mcp-sse-polling-demo --port 3000 --retry-interval 100 +``` + +## Tool: process_batch + +Processes items with periodic checkpoints that trigger SSE stream closes: + +- `items`: Number of items to process (1-100, default: 10) +- `checkpoint_every`: Close stream after this many items (1-20, default: 3) + +## Client + +Use the companion `mcp-sse-polling-client` to test: + +```bash +uv run mcp-sse-polling-client --url http://localhost:3000/mcp +``` diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__init__.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__init__.py new file mode 100644 index 0000000000..46af2fdeed --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__init__.py @@ -0,0 +1 @@ +"""SSE Polling Demo Server - demonstrates close_sse_stream for long-running tasks.""" diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py new file mode 100644 index 0000000000..23cfc85e11 --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/__main__.py @@ -0,0 +1,6 @@ +"""Entry point for the SSE Polling Demo server.""" + +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py new file mode 100644 index 0000000000..75f98cdd49 --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/event_store.py @@ -0,0 +1,100 @@ +""" +In-memory event store for demonstrating resumability functionality. + +This is a simple implementation intended for examples and testing, +not for production use where a persistent storage solution would be more appropriate. +""" + +import logging +from collections import deque +from dataclasses import dataclass +from uuid import uuid4 + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +@dataclass +class EventEntry: + """Represents an event entry in the event store.""" + + event_id: EventId + stream_id: StreamId + message: JSONRPCMessage | None # None for priming events + + +class InMemoryEventStore(EventStore): + """ + Simple in-memory implementation of the EventStore interface for resumability. + This is primarily intended for examples and testing, not for production use + where a persistent storage solution would be more appropriate. + + This implementation keeps only the last N events per stream for memory efficiency. + """ + + def __init__(self, max_events_per_stream: int = 100): + """Initialize the event store. + + Args: + max_events_per_stream: Maximum number of events to keep per stream + """ + self.max_events_per_stream = max_events_per_stream + # for maintaining last N events per stream + self.streams: dict[StreamId, deque[EventEntry]] = {} + # event_id -> EventEntry for quick lookup + self.event_index: dict[EventId, EventEntry] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + """Stores an event with a generated event ID. + + Args: + stream_id: ID of the stream the event belongs to + message: The message to store, or None for priming events + """ + event_id = str(uuid4()) + event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) + + # Get or create deque for this stream + if stream_id not in self.streams: + self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) + + # If deque is full, the oldest event will be automatically removed + # We need to remove it from the event_index as well + if len(self.streams[stream_id]) == self.max_events_per_stream: + oldest_event = self.streams[stream_id][0] + self.event_index.pop(oldest_event.event_id, None) + + # Add new event + self.streams[stream_id].append(event_entry) + self.event_index[event_id] = event_entry + + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replays events that occurred after the specified event ID.""" + if last_event_id not in self.event_index: + logger.warning(f"Event ID {last_event_id} not found in store") + return None + + # Get the stream and find events after the last one + last_event = self.event_index[last_event_id] + stream_id = last_event.stream_id + stream_events = self.streams.get(last_event.stream_id, deque()) + + # Events in deque are already in chronological order + found_last = False + for event in stream_events: + if found_last: + # Skip priming events (None messages) during replay + if event.message is not None: + await send_callback(EventMessage(event.message, event.event_id)) + elif event.event_id == last_event_id: + found_last = True + + return stream_id diff --git a/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py new file mode 100644 index 0000000000..e4bdcaa396 --- /dev/null +++ b/examples/servers/sse-polling-demo/mcp_sse_polling_demo/server.py @@ -0,0 +1,177 @@ +""" +SSE Polling Demo Server + +Demonstrates the SSE polling pattern with close_sse_stream() for long-running tasks. + +Features demonstrated: +- Priming events (automatic with EventStore) +- Server-initiated stream close via close_sse_stream callback +- Client auto-reconnect with Last-Event-ID +- Progress notifications during long-running tasks + +Run with: + uv run mcp-sse-polling-demo --port 3000 +""" + +import contextlib +import logging +from collections.abc import AsyncIterator +from typing import Any + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount +from starlette.types import Receive, Scope, Send + +from .event_store import InMemoryEventStore + +logger = logging.getLogger(__name__) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR)", +) +@click.option( + "--retry-interval", + default=100, + help="SSE retry interval in milliseconds (sent to client)", +) +def main(port: int, log_level: str, retry_interval: int) -> int: + """Run the SSE Polling Demo server.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + # Create the lowlevel server + app = Server("sse-polling-demo") + + @app.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: + """Handle tool calls.""" + ctx = app.request_context + + if name == "process_batch": + items = arguments.get("items", 10) + checkpoint_every = arguments.get("checkpoint_every", 3) + + if items < 1 or items > 100: + return [types.TextContent(type="text", text="Error: items must be between 1 and 100")] + if checkpoint_every < 1 or checkpoint_every > 20: + return [types.TextContent(type="text", text="Error: checkpoint_every must be between 1 and 20")] + + await ctx.session.send_log_message( + level="info", + data=f"Starting batch processing of {items} items...", + logger="process_batch", + related_request_id=ctx.request_id, + ) + + for i in range(1, items + 1): + # Simulate work + await anyio.sleep(0.5) + + # Report progress + await ctx.session.send_log_message( + level="info", + data=f"[{i}/{items}] Processing item {i}", + logger="process_batch", + related_request_id=ctx.request_id, + ) + + # Checkpoint: close stream to trigger client reconnect + if i % checkpoint_every == 0 and i < items: + await ctx.session.send_log_message( + level="info", + data=f"Checkpoint at item {i} - closing SSE stream for polling", + logger="process_batch", + related_request_id=ctx.request_id, + ) + if ctx.close_sse_stream: + logger.info(f"Closing SSE stream at checkpoint {i}") + await ctx.close_sse_stream() + # Wait for client to reconnect (must be > retry_interval of 100ms) + await anyio.sleep(0.2) + + return [ + types.TextContent( + type="text", + text=f"Successfully processed {items} items with checkpoints every {checkpoint_every} items", + ) + ] + + return [types.TextContent(type="text", text=f"Unknown tool: {name}")] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + """List available tools.""" + return [ + types.Tool( + name="process_batch", + description=( + "Process a batch of items with periodic checkpoints. " + "Demonstrates SSE polling where server closes stream periodically." + ), + inputSchema={ + "type": "object", + "properties": { + "items": { + "type": "integer", + "description": "Number of items to process (1-100)", + "default": 10, + }, + "checkpoint_every": { + "type": "integer", + "description": "Close stream after this many items (1-20)", + "default": 3, + }, + }, + }, + ) + ] + + # Create event store for resumability + event_store = InMemoryEventStore() + + # Create session manager with event store and retry interval + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=event_store, + retry_interval=retry_interval, + ) + + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + logger.info(f"SSE Polling Demo server started on port {port}") + logger.info("Try: POST /mcp with tools/call for 'process_batch'") + yield + logger.info("Server shutting down...") + + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/servers/sse-polling-demo/pyproject.toml b/examples/servers/sse-polling-demo/pyproject.toml new file mode 100644 index 0000000000..f7ad89217c --- /dev/null +++ b/examples/servers/sse-polling-demo/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-sse-polling-demo" +version = "0.1.0" +description = "Demo server showing SSE polling with close_sse_stream for long-running tasks" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "sse", "polling", "streamable", "http"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.2.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-sse-polling-demo = "mcp_sse_polling_demo.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_sse_polling_demo"] + +[tool.pyright] +include = ["mcp_sse_polling_demo"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__init__.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__init__.py new file mode 100644 index 0000000000..c65905675b --- /dev/null +++ b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__init__.py @@ -0,0 +1 @@ +"""Example of structured output with low-level MCP server.""" diff --git a/examples/servers/structured_output_lowlevel.py b/examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py similarity index 100% rename from examples/servers/structured_output_lowlevel.py rename to examples/servers/structured-output-lowlevel/mcp_structured_output_lowlevel/__main__.py diff --git a/examples/servers/structured-output-lowlevel/pyproject.toml b/examples/servers/structured-output-lowlevel/pyproject.toml new file mode 100644 index 0000000000..554efc6145 --- /dev/null +++ b/examples/servers/structured-output-lowlevel/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "mcp-structured-output-lowlevel" +version = "0.1.0" +description = "Example of structured output with low-level MCP server" +requires-python = ">=3.10" +dependencies = ["mcp"] diff --git a/examples/snippets/clients/oauth_client.py b/examples/snippets/clients/oauth_client.py index 45026590a5..140b38aedb 100644 --- a/examples/snippets/clients/oauth_client.py +++ b/examples/snippets/clients/oauth_client.py @@ -10,11 +10,12 @@ import asyncio from urllib.parse import parse_qs, urlparse +import httpx from pydantic import AnyUrl from mcp import ClientSession from mcp.client.auth import OAuthClientProvider, TokenStorage -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -68,15 +69,16 @@ async def main(): callback_handler=handle_callback, ) - async with streamablehttp_client("http://localhost:8001/mcp", auth=oauth_auth) as (read, write, _): - async with ClientSession(read, write) as session: - await session.initialize() + async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client: + async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() - tools = await session.list_tools() - print(f"Available tools: {[tool.name for tool in tools.tools]}") + tools = await session.list_tools() + print(f"Available tools: {[tool.name for tool in tools.tools]}") - resources = await session.list_resources() - print(f"Available resources: {[r.uri for r in resources.resources]}") + resources = await session.list_resources() + print(f"Available resources: {[r.uri for r in resources.resources]}") def run(): diff --git a/examples/snippets/clients/pagination_client.py b/examples/snippets/clients/pagination_client.py index 4df1aec600..1805d2d315 100644 --- a/examples/snippets/clients/pagination_client.py +++ b/examples/snippets/clients/pagination_client.py @@ -6,7 +6,7 @@ from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.types import Resource +from mcp.types import PaginatedRequestParams, Resource async def list_all_resources() -> None: @@ -23,7 +23,7 @@ async def list_all_resources() -> None: while True: # Fetch a page of resources - result = await session.list_resources(cursor=cursor) + result = await session.list_resources(params=PaginatedRequestParams(cursor=cursor)) all_resources.extend(result.resources) print(f"Fetched {len(result.resources)} resources") diff --git a/examples/snippets/clients/streamable_basic.py b/examples/snippets/clients/streamable_basic.py index 108439613e..071ea81553 100644 --- a/examples/snippets/clients/streamable_basic.py +++ b/examples/snippets/clients/streamable_basic.py @@ -6,12 +6,12 @@ import asyncio from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client async def main(): # Connect to a streamable HTTP server - async with streamablehttp_client("http://localhost:8000/mcp") as ( + async with streamable_http_client("http://localhost:8000/mcp") as ( read_stream, write_stream, _, diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py new file mode 100644 index 0000000000..56457512c6 --- /dev/null +++ b/examples/snippets/clients/url_elicitation_client.py @@ -0,0 +1,318 @@ +"""URL Elicitation Client Example. + +Demonstrates how clients handle URL elicitation requests from servers. +This is the Python equivalent of TypeScript SDK's elicitationUrlExample.ts, +focused on URL elicitation patterns without OAuth complexity. + +Features demonstrated: +1. Client elicitation capability declaration +2. Handling elicitation requests from servers via callback +3. Catching UrlElicitationRequiredError from tool calls +4. Browser interaction with security warnings +5. Interactive CLI for testing + +Run with: + cd examples/snippets + uv run elicitation-client + +Requires a server with URL elicitation tools running. Start the elicitation +server first: + uv run server elicitation sse +""" + +from __future__ import annotations + +import asyncio +import json +import subprocess +import sys +import webbrowser +from typing import Any +from urllib.parse import urlparse + +from mcp import ClientSession, types +from mcp.client.sse import sse_client +from mcp.shared.context import RequestContext +from mcp.shared.exceptions import McpError, UrlElicitationRequiredError +from mcp.types import URL_ELICITATION_REQUIRED + + +async def handle_elicitation( + context: RequestContext[ClientSession, Any], + params: types.ElicitRequestParams, +) -> types.ElicitResult | types.ErrorData: + """Handle elicitation requests from the server. + + This callback is invoked when the server sends an elicitation/request. + For URL mode, we prompt the user and optionally open their browser. + """ + if params.mode == "url": + return await handle_url_elicitation(params) + else: + # We only support URL mode in this example + return types.ErrorData( + code=types.INVALID_REQUEST, + message=f"Unsupported elicitation mode: {params.mode}", + ) + + +async def handle_url_elicitation( + params: types.ElicitRequestParams, +) -> types.ElicitResult: + """Handle URL mode elicitation - show security warning and optionally open browser. + + This function demonstrates the security-conscious approach to URL elicitation: + 1. Display the full URL and domain for user inspection + 2. Show the server's reason for requesting this interaction + 3. Require explicit user consent before opening any URL + """ + # Extract URL parameters - these are available on URL mode requests + url = getattr(params, "url", None) + elicitation_id = getattr(params, "elicitationId", None) + message = params.message + + if not url: + print("Error: No URL provided in elicitation request") + return types.ElicitResult(action="http://github.com/cancel") + + # Extract domain for security display + domain = extract_domain(url) + + # Security warning - always show the user what they're being asked to do + print("\n" + "=" * 60) + print("SECURITY WARNING: External URL Request") + print("=" * 60) + print("\nThe server is requesting you to open an external URL.") + print(f"\n Domain: {domain}") + print(f" Full URL: {url}") + print("\n Server's reason:") + print(f" {message}") + print(f"\n Elicitation ID: {elicitation_id}") + print("\n" + "-" * 60) + + # Get explicit user consent + try: + response = input("\nOpen this URL in your browser? (y/n): ").strip().lower() + except EOFError: + return types.ElicitResult(action="http://github.com/cancel") + + if response in ("n", "no"): + print("URL navigation declined.") + return types.ElicitResult(action="http://github.com/decline") + elif response not in ("y", "yes"): + print("Invalid response. Cancelling.") + return types.ElicitResult(action="http://github.com/cancel") + + # Open the browser + print(f"\nOpening browser to: {url}") + open_browser(url) + + print("Waiting for you to complete the interaction in your browser...") + print("(The server will continue once you've finished)") + + return types.ElicitResult(action="http://github.com/accept") + + +def extract_domain(url: str) -> str: + """Extract domain from URL for security display.""" + try: + return urlparse(url).netloc + except Exception: + return "unknown" + + +def open_browser(url: str) -> None: + """Open URL in the default browser.""" + try: + if sys.platform == "darwin": + subprocess.run(["open", url], check=False) + elif sys.platform == "win32": + subprocess.run(["start", url], shell=True, check=False) + else: + webbrowser.open(url) + except Exception as e: + print(f"Failed to open browser: {e}") + print(f"Please manually open: {url}") + + +async def call_tool_with_error_handling( + session: ClientSession, + tool_name: str, + arguments: dict[str, Any], +) -> types.CallToolResult | None: + """Call a tool, handling UrlElicitationRequiredError if raised. + + When a server tool needs URL elicitation before it can proceed, + it can either: + 1. Send an elicitation request directly (handled by elicitation_callback) + 2. Return an error with code -32042 (URL_ELICITATION_REQUIRED) + + This function demonstrates handling case 2 - catching the error + and processing the required URL elicitations. + """ + try: + result = await session.call_tool(tool_name, arguments) + + # Check if the tool returned an error in the result + if result.isError: + print(f"Tool returned error: {result.content}") + return None + + return result + + except McpError as e: + # Check if this is a URL elicitation required error + if e.error.code == URL_ELICITATION_REQUIRED: + print("\n[Tool requires URL elicitation to proceed]") + + # Convert to typed error to access elicitations + url_error = UrlElicitationRequiredError.from_error(e.error) + + # Process each required elicitation + for elicitation in url_error.elicitations: + await handle_url_elicitation(elicitation) + + return None + else: + # Re-raise other MCP errors + print(f"MCP Error: {e.error.message} (code: {e.error.code})") + return None + + +def print_help() -> None: + """Print available commands.""" + print("\nAvailable commands:") + print(" list-tools - List available tools") + print(" call [json-args] - Call a tool with optional JSON arguments") + print(" secure-payment - Test URL elicitation via ctx.elicit_url()") + print(" connect-service - Test URL elicitation via UrlElicitationRequiredError") + print(" help - Show this help") + print(" quit - Exit the program") + + +def print_tool_result(result: types.CallToolResult | None) -> None: + """Print a tool call result.""" + if not result: + return + print("\nTool result:") + for content in result.content: + if isinstance(content, types.TextContent): + print(f" {content.text}") + else: + print(f" [{content.type}]") + + +async def handle_list_tools(session: ClientSession) -> None: + """Handle the list-tools command.""" + tools = await session.list_tools() + if tools.tools: + print("\nAvailable tools:") + for tool in tools.tools: + print(f" - {tool.name}: {tool.description or 'No description'}") + else: + print("No tools available") + + +async def handle_call_command(session: ClientSession, command: str) -> None: + """Handle the call command.""" + parts = command.split(maxsplit=2) + if len(parts) < 2: + print("Usage: call [json-args]") + return + + tool_name = parts[1] + args: dict[str, Any] = {} + if len(parts) > 2: + try: + args = json.loads(parts[2]) + except json.JSONDecodeError as e: + print(f"Invalid JSON arguments: {e}") + return + + print(f"\nCalling tool '{tool_name}' with args: {args}") + result = await call_tool_with_error_handling(session, tool_name, args) + print_tool_result(result) + + +async def process_command(session: ClientSession, command: str) -> bool: + """Process a single command. Returns False if should exit.""" + if command in {"quit", "exit"}: + print("Goodbye!") + return False + + if command == "help": + print_help() + elif command == "list-tools": + await handle_list_tools(session) + elif command.startswith("call "): + await handle_call_command(session, command) + elif command == "secure-payment": + print("\nTesting secure_payment tool (uses ctx.elicit_url())...") + result = await call_tool_with_error_handling(session, "secure_payment", {"amount": 99.99}) + print_tool_result(result) + elif command == "connect-service": + print("\nTesting connect_service tool (raises UrlElicitationRequiredError)...") + result = await call_tool_with_error_handling(session, "connect_service", {"service_name": "github"}) + print_tool_result(result) + else: + print(f"Unknown command: {command}") + print("Type 'help' for available commands.") + + return True + + +async def run_command_loop(session: ClientSession) -> None: + """Run the interactive command loop.""" + while True: + try: + command = input("> ").strip() + except EOFError: + break + except KeyboardInterrupt: + print("\n") + break + + if not command: + continue + + if not await process_command(session, command): + break + + +async def main() -> None: + """Run the interactive URL elicitation client.""" + server_url = "http://localhost:8000/sse" + + print("=" * 60) + print("URL Elicitation Client Example") + print("=" * 60) + print(f"\nConnecting to: {server_url}") + print("(Start server with: cd examples/snippets && uv run server elicitation sse)") + + try: + async with sse_client(server_url) as (read, write): + async with ClientSession( + read, + write, + elicitation_callback=handle_elicitation, + ) as session: + await session.initialize() + print("\nConnected! Type 'help' for available commands.\n") + await run_command_loop(session) + + except ConnectionRefusedError: + print(f"\nError: Could not connect to {server_url}") + print("Make sure the elicitation server is running:") + print(" cd examples/snippets && uv run server elicitation sse") + except Exception as e: + print(f"\nError: {e}") + raise + + +def run() -> None: + """Entry point for the client script.""" + asyncio.run(main()) + + +if __name__ == "__main__": + run() diff --git a/examples/snippets/pyproject.toml b/examples/snippets/pyproject.toml index 76791a55a7..4e68846a09 100644 --- a/examples/snippets/pyproject.toml +++ b/examples/snippets/pyproject.toml @@ -21,3 +21,4 @@ completion-client = "clients.completion_client:main" direct-execution-server = "servers.direct_execution:main" display-utilities-client = "clients.display_utilities:main" oauth-client = "clients.oauth_client:run" +elicitation-client = "clients.url_elicitation_client:run" diff --git a/examples/snippets/servers/direct_call_tool_result.py b/examples/snippets/servers/direct_call_tool_result.py new file mode 100644 index 0000000000..54d49b2f66 --- /dev/null +++ b/examples/snippets/servers/direct_call_tool_result.py @@ -0,0 +1,42 @@ +"""Example showing direct CallToolResult return for advanced control.""" + +from typing import Annotated + +from pydantic import BaseModel + +from mcp.server.fastmcp import FastMCP +from mcp.types import CallToolResult, TextContent + +mcp = FastMCP("CallToolResult Example") + + +class ValidationModel(BaseModel): + """Model for validating structured output.""" + + status: str + data: dict[str, int] + + +@mcp.tool() +def advanced_tool() -> CallToolResult: + """Return CallToolResult directly for full control including _meta field.""" + return CallToolResult( + content=[TextContent(type="text", text="Response visible to the model")], + _meta={"hidden": "data for client applications only"}, + ) + + +@mcp.tool() +def validated_tool() -> Annotated[CallToolResult, ValidationModel]: + """Return CallToolResult with structured output validation.""" + return CallToolResult( + content=[TextContent(type="text", text="Validated response")], + structuredContent={"status": "success", "data": {"result": 42}}, + _meta={"internal": "metadata"}, + ) + + +@mcp.tool() +def empty_result_tool() -> CallToolResult: + """For empty results, return CallToolResult with empty content.""" + return CallToolResult(content=[]) diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 2c8a3b35ac..a1a65fb32c 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,18 @@ +"""Elicitation examples demonstrating form and URL mode elicitation. + +Form mode elicitation collects structured, non-sensitive data through a schema. +URL mode elicitation directs users to external URLs for sensitive operations +like OAuth flows, credential collection, or payment processing. +""" + +import uuid + from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.types import ElicitRequestURLParams mcp = FastMCP(name="Elicitation Example") @@ -18,7 +29,10 @@ class BookingPreferences(BaseModel): @mcp.tool() async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: - """Book a table with date availability check.""" + """Book a table with date availability check. + + This demonstrates form mode elicitation for collecting non-sensitive user input. + """ # Check if date is available if date == "2024-12-25": # Date unavailable - ask user for alternative @@ -35,3 +49,51 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS # Date available return f"[SUCCESS] Booked for {date} at {time}" + + +@mcp.tool() +async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str: + """Process a secure payment requiring URL confirmation. + + This demonstrates URL mode elicitation using ctx.elicit_url() for + operations that require out-of-band user interaction. + """ + elicitation_id = str(uuid.uuid4()) + + result = await ctx.elicit_url( + message=f"Please confirm payment of ${amount:.2f}", + url=f"https://payments.example.com/confirm?amount={amount}&id={elicitation_id}", + elicitation_id=elicitation_id, + ) + + if result.action == "accept": + # In a real app, the payment confirmation would happen out-of-band + # and you'd verify the payment status from your backend + return f"Payment of ${amount:.2f} initiated - check your browser to complete" + elif result.action == "decline": + return "Payment declined by user" + return "Payment cancelled" + + +@mcp.tool() +async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + """Connect to a third-party service requiring OAuth authorization. + + This demonstrates the "throw error" pattern using UrlElicitationRequiredError. + Use this pattern when the tool cannot proceed without user authorization. + """ + elicitation_id = str(uuid.uuid4()) + + # Raise UrlElicitationRequiredError to signal that the client must complete + # a URL elicitation before this request can be processed. + # The MCP framework will convert this to a -32042 error response. + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message=f"Authorization required to connect to {service_name}", + url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}", + elicitationId=elicitation_id, + ) + ] + ) diff --git a/examples/snippets/servers/fastmcp_quickstart.py b/examples/snippets/servers/fastmcp_quickstart.py index d7aef8c610..931cd263f8 100644 --- a/examples/snippets/servers/fastmcp_quickstart.py +++ b/examples/snippets/servers/fastmcp_quickstart.py @@ -1,14 +1,14 @@ """ FastMCP quickstart example. -cd to the `examples/snippets/clients` directory and run: - uv run server fastmcp_quickstart stdio +Run from the repository root: + uv run examples/snippets/servers/fastmcp_quickstart.py """ from mcp.server.fastmcp import FastMCP # Create an MCP server -mcp = FastMCP("Demo") +mcp = FastMCP("Demo", json_response=True) # Add an addition tool @@ -36,3 +36,8 @@ def greet_user(name: str, style: str = "friendly") -> str: } return f"{styles.get(style, styles['friendly'])} for someone named {name}." + + +# Run with streamable HTTP transport +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/examples/snippets/servers/lowlevel/direct_call_tool_result.py b/examples/snippets/servers/lowlevel/direct_call_tool_result.py new file mode 100644 index 0000000000..496eaad105 --- /dev/null +++ b/examples/snippets/servers/lowlevel/direct_call_tool_result.py @@ -0,0 +1,65 @@ +""" +Run from the repository root: + uv run examples/snippets/servers/lowlevel/direct_call_tool_result.py +""" + +import asyncio +from typing import Any + +import mcp.server.stdio +import mcp.types as types +from mcp.server.lowlevel import NotificationOptions, Server +from mcp.server.models import InitializationOptions + +server = Server("example-server") + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + """List available tools.""" + return [ + types.Tool( + name="advanced_tool", + description="Tool with full control including _meta field", + inputSchema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ) + ] + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult: + """Handle tool calls by returning CallToolResult directly.""" + if name == "advanced_tool": + message = str(arguments.get("message", "")) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Processed: {message}")], + structuredContent={"result": "success", "message": message}, + _meta={"hidden": "data for client applications only"}, + ) + + raise ValueError(f"Unknown tool: {name}") + + +async def run(): + """Run the server.""" + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="example", + server_version="0.1.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + +if __name__ == "__main__": + asyncio.run(run()) diff --git a/examples/snippets/servers/oauth_server.py b/examples/snippets/servers/oauth_server.py index bd317e1ae5..3717c66de8 100644 --- a/examples/snippets/servers/oauth_server.py +++ b/examples/snippets/servers/oauth_server.py @@ -20,6 +20,7 @@ async def verify_token(self, token: str) -> AccessToken | None: # Create FastMCP instance as a Resource Server mcp = FastMCP( "Weather Service", + json_response=True, # Token verifier for authentication token_verifier=SimpleTokenVerifier(), # Auth settings for RFC 9728 Protected Resource Metadata diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py index 0099836c28..ae78a74ace 100644 --- a/examples/snippets/servers/sampling.py +++ b/examples/snippets/servers/sampling.py @@ -20,6 +20,7 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: max_tokens=100, ) + # Since we're not passing tools param, result.content is single content if result.content.type == "text": return result.content.text return str(result.content) diff --git a/examples/snippets/servers/streamable_config.py b/examples/snippets/servers/streamable_config.py index e265f6381b..d351a45d86 100644 --- a/examples/snippets/servers/streamable_config.py +++ b/examples/snippets/servers/streamable_config.py @@ -5,15 +5,15 @@ from mcp.server.fastmcp import FastMCP -# Stateful server (maintains session state) -mcp = FastMCP("StatefulServer") +# Stateless server with JSON responses (recommended) +mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) # Other configuration options: -# Stateless server (no session persistence) +# Stateless server with SSE streaming responses # mcp = FastMCP("StatelessServer", stateless_http=True) -# Stateless server (no session persistence, no sse stream with supported client) -# mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) +# Stateful server with session persistence +# mcp = FastMCP("StatefulServer") # Add a simple tool to demonstrate the server diff --git a/examples/snippets/servers/streamable_http_basic_mounting.py b/examples/snippets/servers/streamable_http_basic_mounting.py index abcc0e572c..74aa36ed4f 100644 --- a/examples/snippets/servers/streamable_http_basic_mounting.py +++ b/examples/snippets/servers/streamable_http_basic_mounting.py @@ -5,13 +5,15 @@ uvicorn examples.snippets.servers.streamable_http_basic_mounting:app --reload """ +import contextlib + from starlette.applications import Starlette from starlette.routing import Mount from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = FastMCP("My App") +mcp = FastMCP("My App", json_response=True) @mcp.tool() @@ -20,9 +22,17 @@ def hello() -> str: return "Hello from MCP!" +# Create a lifespan context manager to run the session manager +@contextlib.asynccontextmanager +async def lifespan(app: Starlette): + async with mcp.session_manager.run(): + yield + + # Mount the StreamableHTTP server to the existing ASGI server app = Starlette( routes=[ Mount("/", app=mcp.streamable_http_app()), - ] + ], + lifespan=lifespan, ) diff --git a/examples/snippets/servers/streamable_http_host_mounting.py b/examples/snippets/servers/streamable_http_host_mounting.py index d48558cc8e..3ae9d341e1 100644 --- a/examples/snippets/servers/streamable_http_host_mounting.py +++ b/examples/snippets/servers/streamable_http_host_mounting.py @@ -5,13 +5,15 @@ uvicorn examples.snippets.servers.streamable_http_host_mounting:app --reload """ +import contextlib + from starlette.applications import Starlette from starlette.routing import Host from mcp.server.fastmcp import FastMCP # Create MCP server -mcp = FastMCP("MCP Host App") +mcp = FastMCP("MCP Host App", json_response=True) @mcp.tool() @@ -20,9 +22,17 @@ def domain_info() -> str: return "This is served from mcp.acme.corp" +# Create a lifespan context manager to run the session manager +@contextlib.asynccontextmanager +async def lifespan(app: Starlette): + async with mcp.session_manager.run(): + yield + + # Mount using Host-based routing app = Starlette( routes=[ Host("mcp.acme.corp", app=mcp.streamable_http_app()), - ] + ], + lifespan=lifespan, ) diff --git a/examples/snippets/servers/streamable_http_multiple_servers.py b/examples/snippets/servers/streamable_http_multiple_servers.py index df347b7b30..8d0a1018d2 100644 --- a/examples/snippets/servers/streamable_http_multiple_servers.py +++ b/examples/snippets/servers/streamable_http_multiple_servers.py @@ -5,14 +5,16 @@ uvicorn examples.snippets.servers.streamable_http_multiple_servers:app --reload """ +import contextlib + from starlette.applications import Starlette from starlette.routing import Mount from mcp.server.fastmcp import FastMCP # Create multiple MCP servers -api_mcp = FastMCP("API Server") -chat_mcp = FastMCP("Chat Server") +api_mcp = FastMCP("API Server", json_response=True) +chat_mcp = FastMCP("Chat Server", json_response=True) @api_mcp.tool() @@ -32,10 +34,21 @@ def send_message(message: str) -> str: api_mcp.settings.streamable_http_path = "/" chat_mcp.settings.streamable_http_path = "/" + +# Create a combined lifespan to manage both session managers +@contextlib.asynccontextmanager +async def lifespan(app: Starlette): + async with contextlib.AsyncExitStack() as stack: + await stack.enter_async_context(api_mcp.session_manager.run()) + await stack.enter_async_context(chat_mcp.session_manager.run()) + yield + + # Mount the servers app = Starlette( routes=[ Mount("/api", app=api_mcp.streamable_http_app()), Mount("/chat", app=chat_mcp.streamable_http_app()), - ] + ], + lifespan=lifespan, ) diff --git a/examples/snippets/servers/streamable_http_path_config.py b/examples/snippets/servers/streamable_http_path_config.py index 71228423ea..9fabf12fa7 100644 --- a/examples/snippets/servers/streamable_http_path_config.py +++ b/examples/snippets/servers/streamable_http_path_config.py @@ -12,7 +12,11 @@ # Configure streamable_http_path during initialization # This server will mount at the root of wherever it's mounted -mcp_at_root = FastMCP("My Server", streamable_http_path="/") +mcp_at_root = FastMCP( + "My Server", + json_response=True, + streamable_http_path="/", +) @mcp_at_root.tool() diff --git a/examples/snippets/servers/streamable_starlette_mount.py b/examples/snippets/servers/streamable_starlette_mount.py index 57d2d2ea5b..b3a630b0f5 100644 --- a/examples/snippets/servers/streamable_starlette_mount.py +++ b/examples/snippets/servers/streamable_starlette_mount.py @@ -11,7 +11,7 @@ from mcp.server.fastmcp import FastMCP # Create the Echo server -echo_mcp = FastMCP(name="EchoServer", stateless_http=True) +echo_mcp = FastMCP(name="EchoServer", stateless_http=True, json_response=True) @echo_mcp.tool() @@ -21,7 +21,7 @@ def echo(message: str) -> str: # Create the Math server -math_mcp = FastMCP(name="MathServer", stateless_http=True) +math_mcp = FastMCP(name="MathServer", stateless_http=True, json_response=True) @math_mcp.tool() diff --git a/mkdocs.yml b/mkdocs.yml index 18cbb034bb..22c323d9d4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -18,6 +18,12 @@ nav: - Low-Level Server: low-level-server.md - Authorization: authorization.md - Testing: testing.md + - Experimental: + - Overview: experimental/index.md + - Tasks: + - Introduction: experimental/tasks.md + - Server Implementation: experimental/tasks-server.md + - Client Usage: experimental/tasks-client.md - API Reference: api.md theme: diff --git a/pyproject.toml b/pyproject.toml index 5af7ff4d8a..078a1dfdcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ dependencies = [ "uvicorn>=0.31.1; sys_platform != 'emscripten'", "jsonschema>=4.20.0", "pywin32>=310; sys_platform == 'win32'", + "pyjwt[crypto]>=2.10.1", + "typing-extensions>=4.9.0", + "typing-inspection>=0.4.1", ] [project.optional-dependencies] @@ -45,7 +48,7 @@ mcp = "mcp.cli:app [cli]" [tool.uv] default-groups = ["dev", "docs"] -required-version = ">=0.7.2" +required-version = ">=0.9.5" [dependency-groups] dev = [ @@ -59,6 +62,7 @@ dev = [ "pytest-pretty>=1.2.0", "inline-snapshot>=0.23.0", "dirty-equals>=0.9.0", + "coverage[toml]==7.10.7", ] docs = [ "mkdocs>=1.6.1", @@ -98,8 +102,8 @@ venv = ".venv" # those private functions instead of testing the private functions directly. It makes it easier to maintain the code source # and refactor code that is not public. executionEnvironments = [ - { root = "tests", reportUnusedFunction = false, reportPrivateUsage = false }, - { root = "examples/servers", reportUnusedFunction = false }, + { root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false }, + { root = "examples/servers", reportUnusedFunction = false }, ] [tool.ruff] @@ -109,17 +113,17 @@ extend-exclude = ["README.md"] [tool.ruff.lint] select = [ - "C4", # flake8-comprehensions - "C90", # mccabe - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "PERF", # Perflint - "PL", # Pylint - "UP", # pyupgrade + "C4", # flake8-comprehensions + "C90", # mccabe + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "PERF", # Perflint + "PL", # Pylint + "UP", # pyupgrade ] ignore = ["PERF203", "PLC0415", "PLR0402"] -mccabe.max-complexity = 24 # Default is 10 +mccabe.max-complexity = 24 # Default is 10 [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] @@ -128,13 +132,13 @@ mccabe.max-complexity = 24 # Default is 10 [tool.ruff.lint.pylint] allow-magic-value-types = ["bytes", "float", "int", "str"] -max-args = 23 # Default is 5 -max-branches = 23 # Default is 12 -max-returns = 13 # Default is 6 -max-statements = 102 # Default is 50 +max-args = 23 # Default is 5 +max-branches = 23 # Default is 12 +max-returns = 13 # Default is 6 +max-statements = 102 # Default is 50 [tool.uv.workspace] -members = ["examples/servers/*", "examples/snippets"] +members = ["examples/clients/*", "examples/servers/*", "examples/snippets"] [tool.uv.sources] mcp = { workspace = true } @@ -154,16 +158,53 @@ filterwarnings = [ "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", # pywin32 internal deprecation warning - "ignore:getargs.*The 'u' format is deprecated:DeprecationWarning" + "ignore:getargs.*The 'u' format is deprecated:DeprecationWarning", ] [tool.markdown.lint] -default=true -MD004=false # ul-style - Unordered list style -MD007.indent=2 # ul-indent - Unordered list indentation -MD013=false # line-length - Line length -MD029=false # ol-prefix - Ordered list item prefix -MD033=false # no-inline-html Inline HTML -MD041=false # first-line-heading/first-line-h1 -MD046=false # indented-code-blocks -MD059=false # descriptive-link-text +default = true +MD004 = false # ul-style - Unordered list style +MD007.indent = 2 # ul-indent - Unordered list indentation +MD013 = false # line-length - Line length +MD029 = false # ol-prefix - Ordered list item prefix +MD033 = false # no-inline-html Inline HTML +MD041 = false # first-line-heading/first-line-h1 +MD046 = false # indented-code-blocks +MD059 = false # descriptive-link-text + +# https://coverage.readthedocs.io/en/latest/config.html#run +[tool.coverage.run] +branch = true +patch = ["subprocess"] +concurrency = ["multiprocessing", "thread"] +source = ["src", "tests"] +relative_files = true +omit = [ + "src/mcp/client/__main__.py", + "src/mcp/server/__main__.py", + "src/mcp/os/posix/utilities.py", + "src/mcp/os/win32/utilities.py", +] + +# https://coverage.readthedocs.io/en/latest/config.html#report +[tool.coverage.report] +fail_under = 100 +skip_covered = true +show_missing = true +ignore_errors = true +precision = 2 +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "@overload", + "raise NotImplementedError", + "^\\s*\\.\\.\\.\\s*$", +] + +# https://coverage.readthedocs.io/en/latest/config.html#paths +[tool.coverage.paths] +source = [ + "src/", + "/home/runner/work/python-sdk/python-sdk/src/", + 'D:\a\python-sdk\python-sdk\src', +] diff --git a/scripts/test b/scripts/test new file mode 100755 index 0000000000..0d08e47b1b --- /dev/null +++ b/scripts/test @@ -0,0 +1,7 @@ +#!/bin/sh + +set -ex + +uv run --frozen coverage run -m pytest -n auto $@ +uv run --frozen coverage combine +uv run --frozen coverage report diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index e93b95c902..fbec40d0a9 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -3,7 +3,7 @@ from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession from .server.stdio import stdio_server -from .shared.exceptions import McpError +from .shared.exceptions import McpError, UrlElicitationRequiredError from .types import ( CallToolRequest, ClientCapabilities, @@ -13,6 +13,7 @@ CompleteRequest, CreateMessageRequest, CreateMessageResult, + CreateMessageResultWithTools, ErrorData, GetPromptRequest, GetPromptResult, @@ -41,7 +42,12 @@ ResourcesCapability, ResourceUpdatedNotification, RootsCapability, + SamplingCapability, + SamplingContent, + SamplingContextCapability, SamplingMessage, + SamplingMessageContentBlock, + SamplingToolsCapability, ServerCapabilities, ServerNotification, ServerRequest, @@ -50,7 +56,10 @@ StopReason, SubscribeRequest, Tool, + ToolChoice, + ToolResultContent, ToolsCapability, + ToolUseContent, UnsubscribeRequest, ) from .types import ( @@ -65,8 +74,10 @@ "ClientResult", "ClientSession", "ClientSessionGroup", + "CompleteRequest", "CreateMessageRequest", "CreateMessageResult", + "CreateMessageResultWithTools", "ErrorData", "GetPromptRequest", "GetPromptResult", @@ -77,6 +88,7 @@ "InitializedNotification", "JSONRPCError", "JSONRPCRequest", + "JSONRPCResponse", "ListPromptsRequest", "ListPromptsResult", "ListResourcesRequest", @@ -91,12 +103,17 @@ "PromptsCapability", "ReadResourceRequest", "ReadResourceResult", + "Resource", "ResourcesCapability", "ResourceUpdatedNotification", - "Resource", "RootsCapability", + "SamplingCapability", + "SamplingContent", + "SamplingContextCapability", "SamplingMessage", + "SamplingMessageContentBlock", "SamplingRole", + "SamplingToolsCapability", "ServerCapabilities", "ServerNotification", "ServerRequest", @@ -107,10 +124,12 @@ "StopReason", "SubscribeRequest", "Tool", + "ToolChoice", + "ToolResultContent", "ToolsCapability", + "ToolUseContent", "UnsubscribeRequest", + "UrlElicitationRequiredError", "stdio_client", "stdio_server", - "CompleteRequest", - "JSONRPCResponse", ] diff --git a/src/mcp/cli/__init__.py b/src/mcp/cli/__init__.py index 3ef56d8063..b29bce8878 100644 --- a/src/mcp/cli/__init__.py +++ b/src/mcp/cli/__init__.py @@ -2,5 +2,5 @@ from .cli import app -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover app() diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 6a2effa3be..f2dc6888a1 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -14,7 +14,7 @@ MCP_PACKAGE = "mcp[cli]" -def get_claude_config_path() -> Path | None: +def get_claude_config_path() -> Path | None: # pragma: no cover """Get the Claude config directory based on platform.""" if sys.platform == "win32": path = Path(Path.home(), "AppData", "Roaming", "Claude") @@ -33,7 +33,7 @@ def get_claude_config_path() -> Path | None: def get_uv_path() -> str: """Get the full path to the uv executable.""" uv_path = shutil.which("uv") - if not uv_path: + if not uv_path: # pragma: no cover logger.error( "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH" ) @@ -65,14 +65,14 @@ def update_claude_config( """ config_dir = get_claude_config_path() uv_path = get_uv_path() - if not config_dir: + if not config_dir: # pragma: no cover raise RuntimeError( "Claude Desktop config directory not found. Please ensure Claude Desktop" " is installed and has been run at least once to initialize its config." ) config_file = config_dir / "claude_desktop_config.json" - if not config_file.exists(): + if not config_file.exists(): # pragma: no cover try: config_file.write_text("{}") except Exception: @@ -90,7 +90,7 @@ def update_claude_config( config["mcpServers"] = {} # Always preserve existing env vars and merge with new ones - if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: + if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: # pragma: no cover existing_env = config["mcpServers"][server_name]["env"] if env_vars: # New vars take precedence over existing ones @@ -99,18 +99,18 @@ def update_claude_config( env_vars = existing_env # Build uv run command - args = ["run"] + args = ["run", "--frozen"] # Collect all packages in a set to deduplicate packages = {MCP_PACKAGE} - if with_packages: + if with_packages: # pragma: no cover packages.update(pkg for pkg in with_packages if pkg) # Add all packages with --with for pkg in sorted(packages): args.extend(["--with", pkg]) - if with_editable: + if with_editable: # pragma: no cover args.extend(["--with-editable", str(with_editable)]) # Convert file path to absolute before adding to command @@ -118,7 +118,7 @@ def update_claude_config( if ":" in file_spec: file_path, server_object = file_spec.rsplit(":", 1) file_spec = f"{Path(file_path).resolve()}:{server_object}" - else: + else: # pragma: no cover file_spec = str(Path(file_spec).resolve()) # Add fastmcp run command @@ -127,7 +127,7 @@ def update_claude_config( server_config: dict[str, Any] = {"command": uv_path, "args": args} # Add environment variables if specified - if env_vars: + if env_vars: # pragma: no cover server_config["env"] = env_vars config["mcpServers"][server_name] = server_config @@ -138,7 +138,7 @@ def update_claude_config( extra={"config_file": str(config_file)}, ) return True - except Exception: + except Exception: # pragma: no cover logger.exception( "Failed to update Claude config", extra={ diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index 4a7257a117..c4cae0dce3 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -13,20 +13,20 @@ try: import typer -except ImportError: +except ImportError: # pragma: no cover print("Error: typer is required. Install with 'pip install mcp[cli]'") sys.exit(1) try: from mcp.cli import claude from mcp.server.fastmcp.utilities.logging import get_logger -except ImportError: +except ImportError: # pragma: no cover print("Error: mcp.server.fastmcp is not installed or not in PYTHONPATH") sys.exit(1) try: import dotenv -except ImportError: +except ImportError: # pragma: no cover dotenv = None logger = get_logger("cli") @@ -53,7 +53,7 @@ def _get_npx_command(): return "npx" # On Unix-like systems, just use npx -def _parse_env_var(env_var: str) -> tuple[str, str]: +def _parse_env_var(env_var: str) -> tuple[str, str]: # pragma: no cover """Parse environment variable string in format KEY=VALUE.""" if "=" not in env_var: logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE") @@ -67,7 +67,7 @@ def _build_uv_command( with_editable: Path | None = None, with_packages: list[str] | None = None, ) -> list[str]: - """Build the uv run command that runs a MCP server through mcp run.""" + """Build the uv run command that runs an MCP server through mcp run.""" cmd = ["uv"] cmd.extend(["run", "--with", "mcp"]) @@ -77,7 +77,7 @@ def _build_uv_command( if with_packages: for pkg in with_packages: - if pkg: + if pkg: # pragma: no cover cmd.extend(["--with", pkg]) # Add mcp run command @@ -116,8 +116,8 @@ def _parse_file_path(file_spec: str) -> tuple[Path, str | None]: return file_path, server_object -def _import_server(file: Path, server_object: str | None = None): - """Import a MCP server from a file. +def _import_server(file: Path, server_object: str | None = None): # pragma: no cover + """Import an MCP server from a file. Args: file: Path to the file @@ -209,7 +209,7 @@ def _check_server_object(server_object: Any, object_name: str): @app.command() -def version() -> None: +def version() -> None: # pragma: no cover """Show the MCP version.""" try: version = importlib.metadata.version("mcp") @@ -243,8 +243,8 @@ def dev( help="Additional packages to install", ), ] = [], -) -> None: - """Run a MCP server with the MCP Inspector.""" +) -> None: # pragma: no cover + """Run an MCP server with the MCP Inspector.""" file, server_object = _parse_file_path(file_spec) logger.debug( @@ -316,8 +316,8 @@ def run( help="Transport protocol to use (stdio or sse)", ), ] = None, -) -> None: - """Run a MCP server. +) -> None: # pragma: no cover + """Run an MCP server. The server can be specified in two ways:\n 1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n @@ -411,8 +411,8 @@ def install( resolve_path=True, ), ] = None, -) -> None: - """Install a MCP server in the Claude desktop app. +) -> None: # pragma: no cover + """Install an MCP server in the Claude desktop app. Environment variables are preserved once added and only updated if new values are explicitly provided. diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py deleted file mode 100644 index 376036e8cf..0000000000 --- a/src/mcp/client/auth.py +++ /dev/null @@ -1,551 +0,0 @@ -""" -OAuth2 Authentication implementation for HTTPX. - -Implements authorization code flow with PKCE and automatic token refresh. -""" - -import base64 -import hashlib -import logging -import re -import secrets -import string -import time -from collections.abc import AsyncGenerator, Awaitable, Callable -from dataclasses import dataclass, field -from typing import Protocol -from urllib.parse import urlencode, urljoin, urlparse - -import anyio -import httpx -from pydantic import BaseModel, Field, ValidationError - -from mcp.client.streamable_http import MCP_PROTOCOL_VERSION -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthClientMetadata, - OAuthMetadata, - OAuthToken, - ProtectedResourceMetadata, -) -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url -from mcp.types import LATEST_PROTOCOL_VERSION - -logger = logging.getLogger(__name__) - - -class OAuthFlowError(Exception): - """Base exception for OAuth flow errors.""" - - -class OAuthTokenError(OAuthFlowError): - """Raised when token operations fail.""" - - -class OAuthRegistrationError(OAuthFlowError): - """Raised when client registration fails.""" - - -class PKCEParameters(BaseModel): - """PKCE (Proof Key for Code Exchange) parameters.""" - - code_verifier: str = Field(..., min_length=43, max_length=128) - code_challenge: str = Field(..., min_length=43, max_length=128) - - @classmethod - def generate(cls) -> "PKCEParameters": - """Generate new PKCE parameters.""" - code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) - digest = hashlib.sha256(code_verifier.encode()).digest() - code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") - return cls(code_verifier=code_verifier, code_challenge=code_challenge) - - -class TokenStorage(Protocol): - """Protocol for token storage implementations.""" - - async def get_tokens(self) -> OAuthToken | None: - """Get stored tokens.""" - ... - - async def set_tokens(self, tokens: OAuthToken) -> None: - """Store tokens.""" - ... - - async def get_client_info(self) -> OAuthClientInformationFull | None: - """Get stored client information.""" - ... - - async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: - """Store client information.""" - ... - - -@dataclass -class OAuthContext: - """OAuth flow context.""" - - server_url: str - client_metadata: OAuthClientMetadata - storage: TokenStorage - redirect_handler: Callable[[str], Awaitable[None]] - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] - timeout: float = 300.0 - - # Discovered metadata - protected_resource_metadata: ProtectedResourceMetadata | None = None - oauth_metadata: OAuthMetadata | None = None - auth_server_url: str | None = None - protocol_version: str | None = None - - # Client registration - client_info: OAuthClientInformationFull | None = None - - # Token management - current_tokens: OAuthToken | None = None - token_expiry_time: float | None = None - - # State - lock: anyio.Lock = field(default_factory=anyio.Lock) - - # Discovery state for fallback support - discovery_base_url: str | None = None - discovery_pathname: str | None = None - - def get_authorization_base_url(http://github.com/self,%20server_url:%20str) -> str: - """Extract base URL by removing path component.""" - parsed = urlparse(server_url) - return f"{parsed.scheme}://{parsed.netloc}" - - def update_token_expiry(self, token: OAuthToken) -> None: - """Update token expiry time.""" - if token.expires_in: - self.token_expiry_time = time.time() + token.expires_in - else: - self.token_expiry_time = None - - def is_token_valid(self) -> bool: - """Check if current token is valid.""" - return bool( - self.current_tokens - and self.current_tokens.access_token - and (not self.token_expiry_time or time.time() <= self.token_expiry_time) - ) - - def can_refresh_token(self) -> bool: - """Check if token can be refreshed.""" - return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) - - def clear_tokens(self) -> None: - """Clear current tokens.""" - self.current_tokens = None - self.token_expiry_time = None - - def get_resource_url(http://github.com/self) -> str: - """Get resource URL for RFC 8707. - - Uses PRM resource if it's a valid parent, otherwise uses canonical server URL. - """ - resource = resource_url_from_server_url(http://github.com/self.server_url) - - # If PRM provides a resource that's a valid parent, use it - if self.protected_resource_metadata and self.protected_resource_metadata.resource: - prm_resource = str(self.protected_resource_metadata.resource) - if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): - resource = prm_resource - - return resource - - def should_include_resource_param(self, protocol_version: str | None = None) -> bool: - """Determine if the resource parameter should be included in OAuth requests. - - Returns True if: - - Protected resource metadata is available, OR - - MCP-Protocol-Version header is 2025-06-18 or later - """ - # If we have protected resource metadata, include the resource param - if self.protected_resource_metadata is not None: - return True - - # If no protocol version provided, don't include resource param - if not protocol_version: - return False - - # Check if protocol version is 2025-06-18 or later - # Version format is YYYY-MM-DD, so string comparison works - return protocol_version >= "2025-06-18" - - -class OAuthClientProvider(httpx.Auth): - """ - OAuth2 authentication for httpx. - Handles OAuth flow with automatic client registration and token storage. - """ - - requires_response_body = True - - def __init__( - self, - server_url: str, - client_metadata: OAuthClientMetadata, - storage: TokenStorage, - redirect_handler: Callable[[str], Awaitable[None]], - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]], - timeout: float = 300.0, - ): - """Initialize OAuth2 authentication.""" - self.context = OAuthContext( - server_url=server_url, - client_metadata=client_metadata, - storage=storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - timeout=timeout, - ) - self._initialized = False - - def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: - """ - Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. - - Returns: - Resource metadata URL if found in WWW-Authenticate header, None otherwise - """ - if not init_response or init_response.status_code != 401: - return None - - www_auth_header = init_response.headers.get("WWW-Authenticate") - if not www_auth_header: - return None - - # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) - pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, www_auth_header) - - if match: - # Return quoted value if present, otherwise unquoted value - return match.group(1) or match.group(2) - - return None - - async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: - # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response - url = self._extract_resource_metadata_from_www_auth(init_response) - - if not url: - # Fallback to well-known discovery - auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") - - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - - async def _handle_protected_resource_response(self, response: httpx.Response) -> None: - """Handle discovery response.""" - if response.status_code == 200: - try: - content = await response.aread() - metadata = ProtectedResourceMetadata.model_validate_json(content) - self.context.protected_resource_metadata = metadata - if metadata.authorization_servers: - self.context.auth_server_url = str(metadata.authorization_servers[0]) - except ValidationError: - pass - - def _get_discovery_urls(self) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts.""" - urls: list[str] = [] - auth_server_url = self.context.auth_server_url or self.context.server_url - parsed = urlparse(auth_server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - # RFC 8414: Path-aware OAuth discovery - if parsed.path and parsed.path != "/": - oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oauth_path)) - - # OAuth root fallback - urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - - # RFC 8414 section 5: Path-aware OIDC discovery - # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 - if parsed.path and parsed.path != "/": - oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oidc_path)) - - # OIDC 1.0 fallback (appends to full URL per OIDC spec) - oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" - urls.append(oidc_fallback) - - return urls - - async def _register_client(self) -> httpx.Request | None: - """Build registration request or skip if already registered.""" - if self.context.client_info: - return None - - if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: - registration_url = str(self.context.oauth_metadata.registration_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) - registration_url = urljoin(auth_base_url, "/register") - - registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - - return httpx.Request( - "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} - ) - - async def _handle_registration_response(self, response: httpx.Response) -> None: - """Handle registration response.""" - if response.status_code not in (200, 201): - await response.aread() - raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") - - try: - content = await response.aread() - client_info = OAuthClientInformationFull.model_validate_json(content) - self.context.client_info = client_info - await self.context.storage.set_client_info(client_info) - except ValidationError as e: - raise OAuthRegistrationError(f"Invalid registration response: {e}") - - async def _perform_authorization(self) -> tuple[str, str]: - """Perform the authorization redirect and get auth code.""" - if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: - auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) - auth_endpoint = urljoin(auth_base_url, "/authorize") - - if not self.context.client_info: - raise OAuthFlowError("No client info available for authorization") - - # Generate PKCE parameters - pkce_params = PKCEParameters.generate() - state = secrets.token_urlsafe(32) - - auth_params = { - "response_type": "code", - "client_id": self.context.client_info.client_id, - "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), - "state": state, - "code_challenge": pkce_params.code_challenge, - "code_challenge_method": "S256", - } - - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 - - if self.context.client_metadata.scope: - auth_params["scope"] = self.context.client_metadata.scope - - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" - await self.context.redirect_handler(authorization_url) - - # Wait for callback - auth_code, returned_state = await self.context.callback_handler() - - if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") - - if not auth_code: - raise OAuthFlowError("No authorization code received") - - # Return auth code and code verifier for token exchange - return auth_code, pkce_params.code_verifier - - async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request: - """Build token exchange request.""" - if not self.context.client_info: - raise OAuthFlowError("Missing client info") - - if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: - token_url = str(self.context.oauth_metadata.token_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) - token_url = urljoin(auth_base_url, "/token") - - token_data = { - "grant_type": "authorization_code", - "code": auth_code, - "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), - "client_id": self.context.client_info.client_id, - "code_verifier": code_verifier, - } - - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - token_data["resource"] = self.context.get_resource_url() # RFC 8707 - - if self.context.client_info.client_secret: - token_data["client_secret"] = self.context.client_info.client_secret - - return httpx.Request( - "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} - ) - - async def _handle_token_response(self, response: httpx.Response) -> None: - """Handle token exchange response.""" - if response.status_code != 200: - raise OAuthTokenError(f"Token exchange failed: {response.status_code}") - - try: - content = await response.aread() - token_response = OAuthToken.model_validate_json(content) - - # Validate scopes - if token_response.scope and self.context.client_metadata.scope: - requested_scopes = set(self.context.client_metadata.scope.split()) - returned_scopes = set(token_response.scope.split()) - unauthorized_scopes = returned_scopes - requested_scopes - if unauthorized_scopes: - raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}") - - self.context.current_tokens = token_response - self.context.update_token_expiry(token_response) - await self.context.storage.set_tokens(token_response) - except ValidationError as e: - raise OAuthTokenError(f"Invalid token response: {e}") - - async def _refresh_token(self) -> httpx.Request: - """Build token refresh request.""" - if not self.context.current_tokens or not self.context.current_tokens.refresh_token: - raise OAuthTokenError("No refresh token available") - - if not self.context.client_info: - raise OAuthTokenError("No client info available") - - if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: - token_url = str(self.context.oauth_metadata.token_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) - token_url = urljoin(auth_base_url, "/token") - - refresh_data = { - "grant_type": "refresh_token", - "refresh_token": self.context.current_tokens.refresh_token, - "client_id": self.context.client_info.client_id, - } - - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - - if self.context.client_info.client_secret: - refresh_data["client_secret"] = self.context.client_info.client_secret - - return httpx.Request( - "POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"} - ) - - async def _handle_refresh_response(self, response: httpx.Response) -> bool: - """Handle token refresh response. Returns True if successful.""" - if response.status_code != 200: - logger.warning(f"Token refresh failed: {response.status_code}") - self.context.clear_tokens() - return False - - try: - content = await response.aread() - token_response = OAuthToken.model_validate_json(content) - - self.context.current_tokens = token_response - self.context.update_token_expiry(token_response) - await self.context.storage.set_tokens(token_response) - - return True - except ValidationError: - logger.exception("Invalid refresh response") - self.context.clear_tokens() - return False - - async def _initialize(self) -> None: - """Load stored tokens and client info.""" - self.context.current_tokens = await self.context.storage.get_tokens() - self.context.client_info = await self.context.storage.get_client_info() - self._initialized = True - - def _add_auth_header(self, request: httpx.Request) -> None: - """Add authorization header to request if we have valid tokens.""" - if self.context.current_tokens and self.context.current_tokens.access_token: - request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" - - def _create_oauth_metadata_request(self, url: str) -> httpx.Request: - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: - content = await response.aread() - metadata = OAuthMetadata.model_validate_json(content) - self.context.oauth_metadata = metadata - # Apply default scope if needed - if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: - self.context.client_metadata.scope = " ".join(metadata.scopes_supported) - - async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: - """HTTPX auth flow integration.""" - async with self.context.lock: - if not self._initialized: - await self._initialize() - - # Capture protocol version from request headers - self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - - if not self.context.is_token_valid() and self.context.can_refresh_token(): - # Try to refresh token - refresh_request = await self._refresh_token() - refresh_response = yield refresh_request - - if not await self._handle_refresh_response(refresh_response): - # Refresh failed, need full re-authentication - self._initialized = False - - if self.context.is_token_valid(): - self._add_auth_header(request) - - response = yield request - - if response.status_code == 401: - # Perform full OAuth flow - try: - # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) - discovery_request = await self._discover_protected_resource(response) - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) - - # Step 2: Discover OAuth metadata (with fallback for legacy servers) - discovery_urls = self._get_discovery_urls() - for url in discovery_urls: - oauth_metadata_request = self._create_oauth_metadata_request(url) - oauth_metadata_response = yield oauth_metadata_request - - if oauth_metadata_response.status_code == 200: - try: - await self._handle_oauth_metadata_response(oauth_metadata_response) - break - except ValidationError: - continue - elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: - break # Non-4XX error, stop trying - - # Step 3: Register client if needed - registration_request = await self._register_client() - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request - await self._handle_token_response(token_response) - except Exception: - logger.exception("OAuth flow error") - raise - - # Retry with new tokens - self._add_auth_header(request) - yield request diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py new file mode 100644 index 0000000000..252dfd9e4c --- /dev/null +++ b/src/mcp/client/auth/__init__.py @@ -0,0 +1,21 @@ +""" +OAuth2 Authentication implementation for HTTPX. + +Implements authorization code flow with PKCE and automatic token refresh. +""" + +from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError +from mcp.client.auth.oauth2 import ( + OAuthClientProvider, + PKCEParameters, + TokenStorage, +) + +__all__ = [ + "OAuthClientProvider", + "OAuthFlowError", + "OAuthRegistrationError", + "OAuthTokenError", + "PKCEParameters", + "TokenStorage", +] diff --git a/src/mcp/client/auth/exceptions.py b/src/mcp/client/auth/exceptions.py new file mode 100644 index 0000000000..5ce8777b86 --- /dev/null +++ b/src/mcp/client/auth/exceptions.py @@ -0,0 +1,10 @@ +class OAuthFlowError(Exception): + """Base exception for OAuth flow errors.""" + + +class OAuthTokenError(OAuthFlowError): + """Raised when token operations fail.""" + + +class OAuthRegistrationError(OAuthFlowError): + """Raised when client registration fails.""" diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py new file mode 100644 index 0000000000..e2f3f08a4d --- /dev/null +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -0,0 +1,487 @@ +""" +OAuth client credential extensions for MCP. + +Provides OAuth providers for machine-to-machine authentication flows: +- ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret +- PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt authentication + (typically using a pre-built JWT from workload identity federation) +- RFC7523OAuthClientProvider: For jwt-bearer grant (RFC 7523 Section 2.1) +""" + +import time +from collections.abc import Awaitable, Callable +from typing import Any, Literal +from uuid import uuid4 + +import httpx +import jwt +from pydantic import BaseModel, Field + +from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + + +class ClientCredentialsOAuthProvider(OAuthClientProvider): + """OAuth provider for client_credentials grant with client_id + client_secret. + + This provider sets client_info directly, bypassing dynamic client registration. + Use this when you already have client credentials (client_id and client_secret). + + Example: + ```python + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/", + storage=my_token_storage, + client_id="my-client-id", + client_secret="my-client-secret", + ) + ``` + """ + + def __init__( + self, + server_url: str, + storage: TokenStorage, + client_id: str, + client_secret: str, + token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic", + scopes: str | None = None, + ) -> None: + """Initialize client_credentials OAuth provider. + + Args: + server_url: The MCP server URL. + storage: Token storage implementation. + client_id: The OAuth client ID. + client_secret: The OAuth client secret. + token_endpoint_auth_method: Authentication method for token endpoint. + Either "client_secret_basic" (default) or "client_secret_post". + scopes: Optional space-separated list of scopes to request. + """ + # Build minimal client_metadata for the base class + client_metadata = OAuthClientMetadata( + redirect_uris=None, + grant_types=["client_credentials"], + token_endpoint_auth_method=token_endpoint_auth_method, + scope=scopes, + ) + super().__init__(server_url, client_metadata, storage, None, None, 300.0) + # Store client_info to be set during _initialize - no dynamic registration needed + self._fixed_client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=client_id, + client_secret=client_secret, + grant_types=["client_credentials"], + token_endpoint_auth_method=token_endpoint_auth_method, + scope=scopes, + ) + + async def _initialize(self) -> None: + """Load stored tokens and set pre-configured client_info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + self.context.client_info = self._fixed_client_info + self._initialized = True + + async def _perform_authorization(self) -> httpx.Request: + """Perform client_credentials authorization.""" + return await self._exchange_token_client_credentials() + + async def _exchange_token_client_credentials(self) -> httpx.Request: + """Build token exchange request for client_credentials grant.""" + token_data: dict[str, Any] = { + "grant_type": "client_credentials", + } + + headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + + # Use standard auth methods (client_secret_basic, client_secret_post, none) + token_data, headers = self.context.prepare_token_auth(token_data, headers) + + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() + + if self.context.client_metadata.scope: + token_data["scope"] = self.context.client_metadata.scope + + token_url = self._get_token_endpoint() + return httpx.Request("POST", token_url, data=token_data, headers=headers) + + +def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]: + """Create an assertion provider that returns a static JWT token. + + Use this when you have a pre-built JWT (e.g., from workload identity federation) + that doesn't need the audience parameter. + + Example: + ```python + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=static_assertion_provider(my_prebuilt_jwt), + ) + ``` + + Args: + token: The pre-built JWT assertion string. + + Returns: + An async callback suitable for use as an assertion_provider. + """ + + async def provider(audience: str) -> str: + return token + + return provider + + +class SignedJWTParameters(BaseModel): + """Parameters for creating SDK-signed JWT assertions. + + Use `create_assertion_provider()` to create an assertion provider callback + for use with `PrivateKeyJWTOAuthProvider`. + + Example: + ```python + jwt_params = SignedJWTParameters( + issuer="my-client-id", + subject="my-client-id", + signing_key=private_key_pem, + ) + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=jwt_params.create_assertion_provider(), + ) + ``` + """ + + issuer: str = Field(description="Issuer for JWT assertions (typically client_id).") + subject: str = Field(description="Subject identifier for JWT assertions (typically client_id).") + signing_key: str = Field(description="Private key for JWT signing (PEM format).") + signing_algorithm: str = Field(default="RS256", description="Algorithm for signing JWT assertions.") + lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.") + additional_claims: dict[str, Any] | None = Field(default=None, description="Additional claims.") + + def create_assertion_provider(self) -> Callable[[str], Awaitable[str]]: + """Create an assertion provider callback for use with PrivateKeyJWTOAuthProvider. + + Returns: + An async callback that takes the audience (authorization server issuer URL) + and returns a signed JWT assertion. + """ + + async def provider(audience: str) -> str: + now = int(time.time()) + claims: dict[str, Any] = { + "iss": self.issuer, + "sub": self.subject, + "aud": audience, + "exp": now + self.lifetime_seconds, + "iat": now, + "jti": str(uuid4()), + } + if self.additional_claims: + claims.update(self.additional_claims) + + return jwt.encode(claims, self.signing_key, algorithm=self.signing_algorithm) + + return provider + + +class PrivateKeyJWTOAuthProvider(OAuthClientProvider): + """OAuth provider for client_credentials grant with private_key_jwt authentication. + + Uses RFC 7523 Section 2.2 for client authentication via JWT assertion. + + The JWT assertion's audience MUST be the authorization server's issuer identifier + (per RFC 7523bis security updates). The `assertion_provider` callback receives + this audience value and must return a JWT with that audience. + + **Option 1: Pre-built JWT via Workload Identity Federation** + + In production scenarios, the JWT assertion is typically obtained from a workload + identity provider (e.g., GCP, AWS IAM, Azure AD): + + ```python + async def get_workload_identity_token(audience: str) -> str: + # Fetch JWT from your identity provider + # The JWT's audience must match the provided audience parameter + return await fetch_token_from_identity_provider(audience=audience) + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=get_workload_identity_token, + ) + ``` + + **Option 2: Static pre-built JWT** + + If you have a static JWT that doesn't need the audience parameter: + + ```python + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=static_assertion_provider(my_prebuilt_jwt), + ) + ``` + + **Option 3: SDK-signed JWT (for testing/simple setups)** + + For testing or simple deployments, use `SignedJWTParameters.create_assertion_provider()`: + + ```python + jwt_params = SignedJWTParameters( + issuer="my-client-id", + subject="my-client-id", + signing_key=private_key_pem, + ) + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=jwt_params.create_assertion_provider(), + ) + ``` + """ + + def __init__( + self, + server_url: str, + storage: TokenStorage, + client_id: str, + assertion_provider: Callable[[str], Awaitable[str]], + scopes: str | None = None, + ) -> None: + """Initialize private_key_jwt OAuth provider. + + Args: + server_url: The MCP server URL. + storage: Token storage implementation. + client_id: The OAuth client ID. + assertion_provider: Async callback that takes the audience (authorization + server's issuer identifier) and returns a JWT assertion. Use + `SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs, + `static_assertion_provider()` for pre-built JWTs, or provide your own + callback for workload identity federation. + scopes: Optional space-separated list of scopes to request. + """ + # Build minimal client_metadata for the base class + client_metadata = OAuthClientMetadata( + redirect_uris=None, + grant_types=["client_credentials"], + token_endpoint_auth_method="private_key_jwt", + scope=scopes, + ) + super().__init__(server_url, client_metadata, storage, None, None, 300.0) + self._assertion_provider = assertion_provider + # Store client_info to be set during _initialize - no dynamic registration needed + self._fixed_client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=client_id, + grant_types=["client_credentials"], + token_endpoint_auth_method="private_key_jwt", + scope=scopes, + ) + + async def _initialize(self) -> None: + """Load stored tokens and set pre-configured client_info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + self.context.client_info = self._fixed_client_info + self._initialized = True + + async def _perform_authorization(self) -> httpx.Request: + """Perform client_credentials authorization with private_key_jwt.""" + return await self._exchange_token_client_credentials() + + async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None: + """Add JWT assertion for client authentication to token endpoint parameters.""" + if not self.context.oauth_metadata: + raise OAuthFlowError("Missing OAuth metadata for private_key_jwt flow") # pragma: no cover + + # Audience MUST be the issuer identifier of the authorization server + # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01 + audience = str(self.context.oauth_metadata.issuer) + assertion = await self._assertion_provider(audience) + + # RFC 7523 Section 2.2: client authentication via JWT + token_data["client_assertion"] = assertion + token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + + async def _exchange_token_client_credentials(self) -> httpx.Request: + """Build token exchange request for client_credentials grant with private_key_jwt.""" + token_data: dict[str, Any] = { + "grant_type": "client_credentials", + } + + headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + + # Add JWT client authentication (RFC 7523 Section 2.2) + await self._add_client_authentication_jwt(token_data=token_data) + + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() + + if self.context.client_metadata.scope: + token_data["scope"] = self.context.client_metadata.scope + + token_url = self._get_token_endpoint() + return httpx.Request("POST", token_url, data=token_data, headers=headers) + + +class JWTParameters(BaseModel): + """JWT parameters.""" + + assertion: str | None = Field( + default=None, + description="JWT assertion for JWT authentication. " + "Will be used instead of generating a new assertion if provided.", + ) + + issuer: str | None = Field(default=None, description="Issuer for JWT assertions.") + subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.") + audience: str | None = Field(default=None, description="Audience for JWT assertions.") + claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.") + jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.") + jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.") + jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.") + + def to_assertion(self, with_audience_fallback: str | None = None) -> str: + if self.assertion is not None: + # Prebuilt JWT (e.g. acquired out-of-band) + assertion = self.assertion + else: + if not self.jwt_signing_key: + raise OAuthFlowError("Missing signing key for JWT bearer grant") # pragma: no cover + if not self.issuer: + raise OAuthFlowError("Missing issuer for JWT bearer grant") # pragma: no cover + if not self.subject: + raise OAuthFlowError("Missing subject for JWT bearer grant") # pragma: no cover + + audience = self.audience if self.audience else with_audience_fallback + if not audience: + raise OAuthFlowError("Missing audience for JWT bearer grant") # pragma: no cover + + now = int(time.time()) + claims: dict[str, Any] = { + "iss": self.issuer, + "sub": self.subject, + "aud": audience, + "exp": now + self.jwt_lifetime_seconds, + "iat": now, + "jti": str(uuid4()), + } + claims.update(self.claims or {}) + + assertion = jwt.encode( + claims, + self.jwt_signing_key, + algorithm=self.jwt_signing_algorithm or "RS256", + ) + return assertion + + +class RFC7523OAuthClientProvider(OAuthClientProvider): + """OAuth client provider for RFC 7523 jwt-bearer grant. + + .. deprecated:: + Use :class:`ClientCredentialsOAuthProvider` for client_credentials with + client_id + client_secret, or :class:`PrivateKeyJWTOAuthProvider` for + client_credentials with private_key_jwt authentication instead. + + This provider supports the jwt-bearer authorization grant (RFC 7523 Section 2.1) + where the JWT itself is the authorization grant. + """ + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + redirect_handler: Callable[[str], Awaitable[None]] | None = None, + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, + timeout: float = 300.0, + jwt_parameters: JWTParameters | None = None, + ) -> None: + import warnings + + warnings.warn( + "RFC7523OAuthClientProvider is deprecated. Use ClientCredentialsOAuthProvider " + "or PrivateKeyJWTOAuthProvider instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout) + self.jwt_parameters = jwt_parameters + + async def _exchange_token_authorization_code( + self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None + ) -> httpx.Request: # pragma: no cover + """Build token exchange request for authorization_code flow.""" + token_data = token_data or {} + if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt": + self._add_client_authentication_jwt(token_data=token_data) + return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data) + + async def _perform_authorization(self) -> httpx.Request: # pragma: no cover + """Perform the authorization flow.""" + if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types: + token_request = await self._exchange_token_jwt_bearer() + return token_request + else: + return await super()._perform_authorization() + + def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover + """Add JWT assertion for client authentication to token endpoint parameters.""" + if not self.jwt_parameters: + raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow") + if not self.context.oauth_metadata: + raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow") + + # We need to set the audience to the issuer identifier of the authorization server + # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523 + issuer = str(self.context.oauth_metadata.issuer) + assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer) + + # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2 + token_data["client_assertion"] = assertion + token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + # We need to set the audience to the resource server, the audience is difference from the one in claims + # it represents the resource server that will validate the token + token_data["audience"] = self.context.get_resource_url() + + async def _exchange_token_jwt_bearer(self) -> httpx.Request: + """Build token exchange request for JWT bearer grant.""" + if not self.context.client_info: + raise OAuthFlowError("Missing client info") # pragma: no cover + if not self.jwt_parameters: + raise OAuthFlowError("Missing JWT parameters") # pragma: no cover + if not self.context.oauth_metadata: + raise OAuthTokenError("Missing OAuth metadata") # pragma: no cover + + # We need to set the audience to the issuer identifier of the authorization server + # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523 + issuer = str(self.context.oauth_metadata.issuer) + assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer) + + token_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": assertion, + } + + if self.context.should_include_resource_param(self.context.protocol_version): # pragma: no branch + token_data["resource"] = self.context.get_resource_url() + + if self.context.client_metadata.scope: # pragma: no branch + token_data["scope"] = self.context.client_metadata.scope + + token_url = self._get_token_endpoint() + return httpx.Request( + "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py new file mode 100644 index 0000000000..ddc61ef663 --- /dev/null +++ b/src/mcp/client/auth/oauth2.py @@ -0,0 +1,616 @@ +""" +OAuth2 Authentication implementation for HTTPX. + +Implements authorization code flow with PKCE and automatic token refresh. +""" + +import base64 +import hashlib +import logging +import secrets +import string +import time +from collections.abc import AsyncGenerator, Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any, Protocol +from urllib.parse import quote, urlencode, urljoin, urlparse + +import anyio +import httpx +from pydantic import BaseModel, Field, ValidationError + +from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError +from mcp.client.auth.utils import ( + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, + create_client_info_from_metadata_url, + create_client_registration_request, + create_oauth_metadata_request, + extract_field_from_www_auth, + extract_resource_metadata_from_www_auth, + extract_scope_from_www_auth, + get_client_metadata_scopes, + handle_auth_metadata_response, + handle_protected_resource_response, + handle_registration_response, + handle_token_response_scopes, + is_valid_client_metadata_url, + should_use_client_metadata_url, +) +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) +from mcp.shared.auth_utils import ( + calculate_token_expiry, + check_resource_allowed, + resource_url_from_server_url, +) + +logger = logging.getLogger(__name__) + + +class PKCEParameters(BaseModel): + """PKCE (Proof Key for Code Exchange) parameters.""" + + code_verifier: str = Field(..., min_length=43, max_length=128) + code_challenge: str = Field(..., min_length=43, max_length=128) + + @classmethod + def generate(cls) -> "PKCEParameters": + """Generate new PKCE parameters.""" + code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") + return cls(code_verifier=code_verifier, code_challenge=code_challenge) + + +class TokenStorage(Protocol): + """Protocol for token storage implementations.""" + + async def get_tokens(self) -> OAuthToken | None: + """Get stored tokens.""" + ... + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens.""" + ... + + async def get_client_info(self) -> OAuthClientInformationFull | None: + """Get stored client information.""" + ... + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information.""" + ... + + +@dataclass +class OAuthContext: + """OAuth flow context.""" + + server_url: str + client_metadata: OAuthClientMetadata + storage: TokenStorage + redirect_handler: Callable[[str], Awaitable[None]] | None + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None + timeout: float = 300.0 + client_metadata_url: str | None = None + + # Discovered metadata + protected_resource_metadata: ProtectedResourceMetadata | None = None + oauth_metadata: OAuthMetadata | None = None + auth_server_url: str | None = None + protocol_version: str | None = None + + # Client registration + client_info: OAuthClientInformationFull | None = None + + # Token management + current_tokens: OAuthToken | None = None + token_expiry_time: float | None = None + + # State + lock: anyio.Lock = field(default_factory=anyio.Lock) + + def get_authorization_base_url(http://github.com/self,%20server_url:%20str) -> str: + """Extract base URL by removing path component.""" + parsed = urlparse(server_url) + return f"{parsed.scheme}://{parsed.netloc}" + + def update_token_expiry(self, token: OAuthToken) -> None: + """Update token expiry time using shared util function.""" + self.token_expiry_time = calculate_token_expiry(token.expires_in) + + def is_token_valid(self) -> bool: + """Check if current token is valid.""" + return bool( + self.current_tokens + and self.current_tokens.access_token + and (not self.token_expiry_time or time.time() <= self.token_expiry_time) + ) + + def can_refresh_token(self) -> bool: + """Check if token can be refreshed.""" + return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) + + def clear_tokens(self) -> None: + """Clear current tokens.""" + self.current_tokens = None + self.token_expiry_time = None + + def get_resource_url(http://github.com/self) -> str: + """Get resource URL for RFC 8707. + + Uses PRM resource if it's a valid parent, otherwise uses canonical server URL. + """ + resource = resource_url_from_server_url(http://github.com/self.server_url) + + # If PRM provides a resource that's a valid parent, use it + if self.protected_resource_metadata and self.protected_resource_metadata.resource: + prm_resource = str(self.protected_resource_metadata.resource) + if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): + resource = prm_resource + + return resource + + def should_include_resource_param(self, protocol_version: str | None = None) -> bool: + """Determine if the resource parameter should be included in OAuth requests. + + Returns True if: + - Protected resource metadata is available, OR + - MCP-Protocol-Version header is 2025-06-18 or later + """ + # If we have protected resource metadata, include the resource param + if self.protected_resource_metadata is not None: + return True + + # If no protocol version provided, don't include resource param + if not protocol_version: + return False + + # Check if protocol version is 2025-06-18 or later + # Version format is YYYY-MM-DD, so string comparison works + return protocol_version >= "2025-06-18" + + def prepare_token_auth( + self, data: dict[str, str], headers: dict[str, str] | None = None + ) -> tuple[dict[str, str], dict[str, str]]: + """Prepare authentication for token requests. + + Args: + data: The form data to send + headers: Optional headers dict to update + + Returns: + Tuple of (updated_data, updated_headers) + """ + if headers is None: + headers = {} # pragma: no cover + + if not self.client_info: + return data, headers # pragma: no cover + + auth_method = self.client_info.token_endpoint_auth_method + + if auth_method == "client_secret_basic" and self.client_info.client_id and self.client_info.client_secret: + # URL-encode client ID and secret per RFC 6749 Section 2.3.1 + encoded_id = quote(self.client_info.client_id, safe="") + encoded_secret = quote(self.client_info.client_secret, safe="") + credentials = f"{encoded_id}:{encoded_secret}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers["Authorization"] = f"Basic {encoded_credentials}" + # Don't include client_secret in body for basic auth + data = {k: v for k, v in data.items() if k != "client_secret"} + elif auth_method == "client_secret_post" and self.client_info.client_secret: + # Include client_secret in request body + data["client_secret"] = self.client_info.client_secret + # For auth_method == "none", don't add any client_secret + + return data, headers + + +class OAuthClientProvider(httpx.Auth): + """ + OAuth2 authentication for httpx. + Handles OAuth flow with automatic client registration and token storage. + """ + + requires_response_body = True + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + redirect_handler: Callable[[str], Awaitable[None]] | None = None, + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, + timeout: float = 300.0, + client_metadata_url: str | None = None, + ): + """Initialize OAuth2 authentication. + + Args: + server_url: The MCP server URL. + client_metadata: OAuth client metadata for registration. + storage: Token storage implementation. + redirect_handler: Handler for authorization redirects. + callback_handler: Handler for authorization callbacks. + timeout: Timeout for the OAuth flow. + client_metadata_url: URL-based client ID. When provided and the server + advertises client_id_metadata_document_supported=true, this URL will be + used as the client_id instead of performing dynamic client registration. + Must be a valid HTTPS URL with a non-root pathname. + + Raises: + ValueError: If client_metadata_url is provided but not a valid HTTPS URL + with a non-root pathname. + """ + # Validate client_metadata_url if provided + if client_metadata_url is not None and not is_valid_client_metadata_url(http://github.com/client_metadata_url): + raise ValueError( + f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}" + ) + + self.context = OAuthContext( + server_url=server_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + timeout=timeout, + client_metadata_url=client_metadata_url, + ) + self._initialized = False + + async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: + """ + Handle protected resource metadata discovery response. + + Per SEP-985, supports fallback when discovery fails at one URL. + + Returns: + True if metadata was successfully discovered, False if we should try next URL + """ + if response.status_code == 200: + try: + content = await response.aread() + metadata = ProtectedResourceMetadata.model_validate_json(content) + self.context.protected_resource_metadata = metadata + if metadata.authorization_servers: # pragma: no branch + self.context.auth_server_url = str(metadata.authorization_servers[0]) + return True + + except ValidationError: # pragma: no cover + # Invalid metadata - try next URL + logger.warning(f"Invalid protected resource metadata at {response.request.url}") + return False + elif response.status_code == 404: # pragma: no cover + # Not found - try next URL in fallback chain + logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") + return False + else: + # Other error - fail immediately + raise OAuthFlowError( + f"Protected Resource Metadata request failed: {response.status_code}" + ) # pragma: no cover + + async def _perform_authorization(self) -> httpx.Request: + """Perform the authorization flow.""" + auth_code, code_verifier = await self._perform_authorization_code_grant() + token_request = await self._exchange_token_authorization_code(auth_code, code_verifier) + return token_request + + async def _perform_authorization_code_grant(self) -> tuple[str, str]: + """Perform the authorization redirect and get auth code.""" + if self.context.client_metadata.redirect_uris is None: + raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover + if not self.context.redirect_handler: + raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover + if not self.context.callback_handler: + raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover + + if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover + else: + auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) + auth_endpoint = urljoin(auth_base_url, "/authorize") + + if not self.context.client_info: + raise OAuthFlowError("No client info available for authorization") # pragma: no cover + + # Generate PKCE parameters + pkce_params = PKCEParameters.generate() + state = secrets.token_urlsafe(32) + + auth_params = { + "response_type": "code", + "client_id": self.context.client_info.client_id, + "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), + "state": state, + "code_challenge": pkce_params.code_challenge, + "code_challenge_method": "S256", + } + + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover + + if self.context.client_metadata.scope: # pragma: no branch + auth_params["scope"] = self.context.client_metadata.scope + + authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + await self.context.redirect_handler(authorization_url) + + # Wait for callback + auth_code, returned_state = await self.context.callback_handler() + + if returned_state is None or not secrets.compare_digest(returned_state, state): + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover + + if not auth_code: + raise OAuthFlowError("No authorization code received") # pragma: no cover + + # Return auth code and code verifier for token exchange + return auth_code, pkce_params.code_verifier + + def _get_token_endpoint(self) -> str: + if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: + token_url = str(self.context.oauth_metadata.token_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) + token_url = urljoin(auth_base_url, "/token") + return token_url + + async def _exchange_token_authorization_code( + self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {} + ) -> httpx.Request: + """Build token exchange request for authorization_code flow.""" + if self.context.client_metadata.redirect_uris is None: + raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover + if not self.context.client_info: + raise OAuthFlowError("Missing client info") # pragma: no cover + + token_url = self._get_token_endpoint() + token_data = token_data or {} + token_data.update( + { + "grant_type": "authorization_code", + "code": auth_code, + "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), + "client_id": self.context.client_info.client_id, + "code_verifier": code_verifier, + } + ) + + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() # RFC 8707 + + # Prepare authentication based on preferred method + headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_data, headers = self.context.prepare_token_auth(token_data, headers) + + return httpx.Request("POST", token_url, data=token_data, headers=headers) + + async def _handle_token_response(self, response: httpx.Response) -> None: + """Handle token exchange response.""" + if response.status_code not in {200, 201}: + body = await response.aread() # pragma: no cover + body_text = body.decode("utf-8") # pragma: no cover + raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover + + # Parse and validate response with scope validation + token_response = await handle_token_response_scopes(response) + + # Store tokens in context + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) + + async def _refresh_token(self) -> httpx.Request: + """Build token refresh request.""" + if not self.context.current_tokens or not self.context.current_tokens.refresh_token: + raise OAuthTokenError("No refresh token available") # pragma: no cover + + if not self.context.client_info or not self.context.client_info.client_id: + raise OAuthTokenError("No client info available") # pragma: no cover + + if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: + token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover + else: + auth_base_url = self.context.get_authorization_base_url(http://github.com/self.context.server_url) + token_url = urljoin(auth_base_url, "/token") + + refresh_data: dict[str, str] = { + "grant_type": "refresh_token", + "refresh_token": self.context.current_tokens.refresh_token, + "client_id": self.context.client_info.client_id, + } + + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 + + # Prepare authentication based on preferred method + headers = {"Content-Type": "application/x-www-form-urlencoded"} + refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers) + + return httpx.Request("POST", token_url, data=refresh_data, headers=headers) + + async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover + """Handle token refresh response. Returns True if successful.""" + if response.status_code != 200: + logger.warning(f"Token refresh failed: {response.status_code}") + self.context.clear_tokens() + return False + + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) + + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) + + return True + except ValidationError: + logger.exception("Invalid refresh response") + self.context.clear_tokens() + return False + + async def _initialize(self) -> None: # pragma: no cover + """Load stored tokens and client info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + self.context.client_info = await self.context.storage.get_client_info() + self._initialized = True + + def _add_auth_header(self, request: httpx.Request) -> None: + """Add authorization header to request if we have valid tokens.""" + if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch + request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: + content = await response.aread() + metadata = OAuthMetadata.model_validate_json(content) + self.context.oauth_metadata = metadata + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + """HTTPX auth flow integration.""" + async with self.context.lock: + if not self._initialized: + await self._initialize() # pragma: no cover + + # Capture protocol version from request headers + self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) + + if not self.context.is_token_valid() and self.context.can_refresh_token(): + # Try to refresh token + refresh_request = await self._refresh_token() # pragma: no cover + refresh_response = yield refresh_request # pragma: no cover + + if not await self._handle_refresh_response(refresh_response): # pragma: no cover + # Refresh failed, need full re-authentication + self._initialized = False + + if self.context.is_token_valid(): + self._add_auth_header(request) + + response = yield request + + if response.status_code == 401: + # Perform full OAuth flow + try: + # OAuth flow must be inline due to generator constraints + www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) + + # Step 1: Discover protected resource metadata (SEP-985 with fallback support) + prm_discovery_urls = build_protected_resource_metadata_discovery_urls( + www_auth_resource_metadata_url, self.context.server_url + ) + + for url in prm_discovery_urls: # pragma: no branch + discovery_request = create_oauth_metadata_request(url) + + discovery_response = yield discovery_request # sending request + + prm = await handle_protected_resource_response(discovery_response) + if prm: + self.context.protected_resource_metadata = prm + + # todo: try all authorization_servers to find the OASM + assert ( + len(prm.authorization_servers) > 0 + ) # this is always true as authorization_servers has a min length of 1 + + self.context.auth_server_url = str(prm.authorization_servers[0]) + break + else: + logger.debug(f"Protected resource metadata discovery failed: {url}") + + asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( + self.context.auth_server_url, self.context.server_url + ) + + # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) + for url in asm_discovery_urls: # pragma: no cover + oauth_metadata_request = create_oauth_metadata_request(url) + oauth_metadata_response = yield oauth_metadata_request + + ok, asm = await handle_auth_metadata_response(oauth_metadata_response) + if not ok: + break + if ok and asm: + self.context.oauth_metadata = asm + break + else: + logger.debug(f"OAuth metadata discovery failed: {url}") + + # Step 3: Apply scope selection strategy + self.context.client_metadata.scope = get_client_metadata_scopes( + extract_scope_from_www_auth(response), + self.context.protected_resource_metadata, + self.context.oauth_metadata, + ) + + # Step 4: Register client or use URL-based client ID (CIMD) + if not self.context.client_info: + if should_use_client_metadata_url( + self.context.oauth_metadata, self.context.client_metadata_url + ): + # Use URL-based client ID (CIMD) + logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}") + client_information = create_client_info_from_metadata_url( + self.context.client_metadata_url, # type: ignore[arg-type] + redirect_uris=self.context.client_metadata.redirect_uris, + ) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) + else: + # Fallback to Dynamic Client Registration + registration_request = create_client_registration_request( + self.context.oauth_metadata, + self.context.client_metadata, + self.context.get_authorization_base_url(http://github.com/self.context.server_url), + ) + registration_response = yield registration_request + client_information = await handle_registration_response(registration_response) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) + + # Step 5: Perform authorization and complete token exchange + token_response = yield await self._perform_authorization() + await self._handle_token_response(token_response) + except Exception: # pragma: no cover + logger.exception("OAuth flow error") + raise + + # Retry with new tokens + self._add_auth_header(request) + yield request + elif response.status_code == 403: + # Step 1: Extract error field from WWW-Authenticate header + error = extract_field_from_www_auth(response, "error") + + # Step 2: Check if we need to step-up authorization + if error == "insufficient_scope": # pragma: no branch + try: + # Step 2a: Update the required scopes + self.context.client_metadata.scope = get_client_metadata_scopes( + extract_scope_from_www_auth(response), self.context.protected_resource_metadata + ) + + # Step 2b: Perform (re-)authorization and token exchange + token_response = yield await self._perform_authorization() + await self._handle_token_response(token_response) + except Exception: # pragma: no cover + logger.exception("OAuth flow error") + raise + + # Retry with new tokens + self._add_auth_header(request) + yield request diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py new file mode 100644 index 0000000000..b4426be7f8 --- /dev/null +++ b/src/mcp/client/auth/utils.py @@ -0,0 +1,336 @@ +import logging +import re +from urllib.parse import urljoin, urlparse + +from httpx import Request, Response +from pydantic import AnyUrl, ValidationError + +from mcp.client.auth import OAuthRegistrationError, OAuthTokenError +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) +from mcp.types import LATEST_PROTOCOL_VERSION + +logger = logging.getLogger(__name__) + + +def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: + """ + Extract field from WWW-Authenticate header. + + Returns: + Field value if found in WWW-Authenticate header, None otherwise + """ + www_auth_header = response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: field_name="value" or field_name=value (unquoted) + pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + +def extract_scope_from_www_auth(response: Response) -> str | None: + """ + Extract scope parameter from WWW-Authenticate header as per RFC6750. + + Returns: + Scope string if found in WWW-Authenticate header, None otherwise + """ + return extract_field_from_www_auth(response, "scope") + + +def extract_resource_metadata_from_www_auth(response: Response) -> str | None: + """ + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not response or response.status_code != 401: + return None # pragma: no cover + + return extract_field_from_www_auth(response, "resource_metadata") + + +def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]: + """ + Build ordered list of URLs to try for protected resource metadata discovery. + + Per SEP-985, the client MUST: + 1. Try resource_metadata from WWW-Authenticate header (if present) + 2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path} + 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource + + Args: + www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header + server_url: server url + + Returns: + Ordered list of URLs to try for discovery + """ + urls: list[str] = [] + + # Priority 1: WWW-Authenticate header with resource_metadata parameter + if www_auth_url: + urls.append(www_auth_url) + + # Priority 2-3: Well-known URIs (RFC 9728) + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # Priority 2: Path-based well-known URI (if server has a path component) + if parsed.path and parsed.path != "/": + path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}") + urls.append(path_based_url) + + # Priority 3: Root-based well-known URI + root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource") + urls.append(root_based_url) + + return urls + + +def get_client_metadata_scopes( + www_authenticate_scope: str | None, + protected_resource_metadata: ProtectedResourceMetadata | None, + authorization_server_metadata: OAuthMetadata | None = None, +) -> str | None: + """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.""" + # Per MCP spec, scope selection priority order: + # 1. Use scope from WWW-Authenticate header (if provided) + # 2. Use all scopes from PRM scopes_supported (if available) + # 3. Omit scope parameter if neither is available + + if www_authenticate_scope is not None: + # Priority 1: WWW-Authenticate header scope + return www_authenticate_scope + elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: + # Priority 2: PRM scopes_supported + return " ".join(protected_resource_metadata.scopes_supported) + elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: + return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover + else: + # Priority 3: Omit scope parameter + return None + + +def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]: + """ + Generate ordered list of (url, type) tuples for discovery attempts. + + Args: + auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None + server_url: URL for the MCP server, used as a fallback if auth_server_url is None + """ + + if not auth_server_url: + # Legacy path using the 2025-03-26 spec: + # link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization + parsed = urlparse(server_url) + return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"] + + urls: list[str] = [] + parsed = urlparse(auth_server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # RFC 8414: Path-aware OAuth discovery + if parsed.path and parsed.path != "/": + oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oauth_path)) + + # RFC 8414 section 5: Path-aware OIDC discovery + # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oidc_path)) + + # https://openid.net/specs/openid-connect-discovery-1_0.html + oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration" + urls.append(urljoin(base_url, oidc_path)) + return urls + + # OAuth root + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) + + # OIDC 1.0 fallback (appends to full URL per OIDC spec) + # https://openid.net/specs/openid-connect-discovery-1_0.html + urls.append(urljoin(base_url, "/.well-known/openid-configuration")) + + return urls + + +async def handle_protected_resource_response( + response: Response, +) -> ProtectedResourceMetadata | None: + """ + Handle protected resource metadata discovery response. + + Per SEP-985, supports fallback when discovery fails at one URL. + + Returns: + True if metadata was successfully discovered, False if we should try next URL + """ + if response.status_code == 200: + try: + content = await response.aread() + metadata = ProtectedResourceMetadata.model_validate_json(content) + return metadata + + except ValidationError: # pragma: no cover + # Invalid metadata - try next URL + return None + else: + # Not found - try next URL in fallback chain + return None + + +async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]: + if response.status_code == 200: + try: + content = await response.aread() + asm = OAuthMetadata.model_validate_json(content) + return True, asm + except ValidationError: # pragma: no cover + return True, None + elif response.status_code < 400 or response.status_code >= 500: + return False, None # Non-4XX error, stop trying + return True, None + + +def create_oauth_metadata_request(url: str) -> Request: + return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + +def create_client_registration_request( + auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str +) -> Request: + """Build registration request or skip if already registered.""" + + if auth_server_metadata and auth_server_metadata.registration_endpoint: + registration_url = str(auth_server_metadata.registration_endpoint) + else: + registration_url = urljoin(auth_base_url, "/register") + + registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) + + return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}) + + +async def handle_registration_response(response: Response) -> OAuthClientInformationFull: + """Handle registration response.""" + if response.status_code not in (200, 201): + await response.aread() + raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") + + try: + content = await response.aread() + client_info = OAuthClientInformationFull.model_validate_json(content) + return client_info + # self.context.client_info = client_info + # await self.context.storage.set_client_info(client_info) + except ValidationError as e: # pragma: no cover + raise OAuthRegistrationError(f"Invalid registration response: {e}") + + +def is_valid_client_metadata_url(/service/url: str | None) -> bool: + """Validate that a URL is suitable for use as a client_id (CIMD). + + The URL must be HTTPS with a non-root pathname. + + Args: + url: The URL to validate + + Returns: + True if the URL is a valid HTTPS URL with a non-root pathname + """ + if not url: + return False + try: + parsed = urlparse(url) + return parsed.scheme == "https" and parsed.path not in ("", "/") + except Exception: + return False + + +def should_use_client_metadata_url( + oauth_metadata: OAuthMetadata | None, + client_metadata_url: str | None, +) -> bool: + """Determine if URL-based client ID (CIMD) should be used instead of DCR. + + URL-based client IDs should be used when: + 1. The server advertises client_id_metadata_document_supported=true + 2. The client has a valid client_metadata_url configured + + Args: + oauth_metadata: OAuth authorization server metadata + client_metadata_url: URL-based client ID (already validated) + + Returns: + True if CIMD should be used, False if DCR should be used + """ + if not client_metadata_url: + return False + + if not oauth_metadata: + return False + + return oauth_metadata.client_id_metadata_document_supported is True + + +def create_client_info_from_metadata_url( + client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None +) -> OAuthClientInformationFull: + """Create client information using a URL-based client ID (CIMD). + + When using URL-based client IDs, the URL itself becomes the client_id + and no client_secret is used (token_endpoint_auth_method="none"). + + Args: + client_metadata_url: The URL to use as the client_id + redirect_uris: The redirect URIs from the client metadata (passed through for + compatibility with OAuthClientInformationFull which inherits from OAuthClientMetadata) + + Returns: + OAuthClientInformationFull with the URL as client_id + """ + return OAuthClientInformationFull( + client_id=client_metadata_url, + token_endpoint_auth_method="none", + redirect_uris=redirect_uris, + ) + + +async def handle_token_response_scopes( + response: Response, +) -> OAuthToken: + """Parse and validate token response with optional scope validation. + + Parses token response JSON. Callers should check response.status_code before calling. + + Args: + response: HTTP response from token endpoint (status already checked by caller) + + Returns: + Validated OAuthToken model + + Raises: + OAuthTokenError: If response JSON is invalid + """ + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) + return token_response + except ValidationError as e: # pragma: no cover + raise OAuthTokenError(f"Invalid token response: {e}") diff --git a/src/mcp/client/experimental/__init__.py b/src/mcp/client/experimental/__init__.py new file mode 100644 index 0000000000..b6579b191e --- /dev/null +++ b/src/mcp/client/experimental/__init__.py @@ -0,0 +1,9 @@ +""" +Experimental client features. + +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.client.experimental.tasks import ExperimentalClientFeatures + +__all__ = ["ExperimentalClientFeatures"] diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py new file mode 100644 index 0000000000..a47508674b --- /dev/null +++ b/src/mcp/client/experimental/task_handlers.py @@ -0,0 +1,290 @@ +""" +Experimental task handler protocols for server -> client requests. + +This module provides Protocol types and default handlers for when servers +send task-related requests to clients (the reverse of normal client -> server flow). + +WARNING: These APIs are experimental and may change without notice. + +Use cases: +- Server sends task-augmented sampling/elicitation request to client +- Client creates a local task, spawns background work, returns CreateTaskResult +- Server polls client's task status via tasks/get, tasks/result, etc. +""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Protocol + +from pydantic import TypeAdapter + +import mcp.types as types +from mcp.shared.context import RequestContext +from mcp.shared.session import RequestResponder + +if TYPE_CHECKING: + from mcp.client.session import ClientSession + + +class GetTaskHandlerFnT(Protocol): + """Handler for tasks/get requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.GetTaskRequestParams, + ) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch + + +class GetTaskResultHandlerFnT(Protocol): + """Handler for tasks/result requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.GetTaskPayloadRequestParams, + ) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch + + +class ListTasksHandlerFnT(Protocol): + """Handler for tasks/list requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.PaginatedRequestParams | None, + ) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch + + +class CancelTaskHandlerFnT(Protocol): + """Handler for tasks/cancel requests from server. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.CancelTaskRequestParams, + ) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch + + +class TaskAugmentedSamplingFnT(Protocol): + """Handler for task-augmented sampling/createMessage requests from server. + + When server sends a CreateMessageRequest with task field, this callback + is invoked. The callback should create a task, spawn background work, + and return CreateTaskResult immediately. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + task_metadata: types.TaskMetadata, + ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch + + +class TaskAugmentedElicitationFnT(Protocol): + """Handler for task-augmented elicitation/create requests from server. + + When server sends an ElicitRequest with task field, this callback + is invoked. The callback should create a task, spawn background work, + and return CreateTaskResult immediately. + + WARNING: This is experimental and may change without notice. + """ + + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + task_metadata: types.TaskMetadata, + ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch + + +async def default_get_task_handler( + context: RequestContext["ClientSession", Any], + params: types.GetTaskRequestParams, +) -> types.GetTaskResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/get not supported", + ) + + +async def default_get_task_result_handler( + context: RequestContext["ClientSession", Any], + params: types.GetTaskPayloadRequestParams, +) -> types.GetTaskPayloadResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/result not supported", + ) + + +async def default_list_tasks_handler( + context: RequestContext["ClientSession", Any], + params: types.PaginatedRequestParams | None, +) -> types.ListTasksResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/list not supported", + ) + + +async def default_cancel_task_handler( + context: RequestContext["ClientSession", Any], + params: types.CancelTaskRequestParams, +) -> types.CancelTaskResult | types.ErrorData: + return types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="tasks/cancel not supported", + ) + + +async def default_task_augmented_sampling( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + task_metadata: types.TaskMetadata, +) -> types.CreateTaskResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Task-augmented sampling not supported", + ) + + +async def default_task_augmented_elicitation( + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + task_metadata: types.TaskMetadata, +) -> types.CreateTaskResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Task-augmented elicitation not supported", + ) + + +@dataclass +class ExperimentalTaskHandlers: + """Container for experimental task handlers. + + Groups all task-related handlers that handle server -> client requests. + This includes both pure task requests (get, list, cancel, result) and + task-augmented request handlers (sampling, elicitation with task field). + + WARNING: These APIs are experimental and may change without notice. + + Example: + handlers = ExperimentalTaskHandlers( + get_task=my_get_task_handler, + list_tasks=my_list_tasks_handler, + ) + session = ClientSession(..., experimental_task_handlers=handlers) + """ + + # Pure task request handlers + get_task: GetTaskHandlerFnT = field(default=default_get_task_handler) + get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler) + list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler) + cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler) + + # Task-augmented request handlers + augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling) + augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation) + + def build_capability(self) -> types.ClientTasksCapability | None: + """Build ClientTasksCapability from the configured handlers. + + Returns a capability object that reflects which handlers are configured + (i.e., not using the default "not supported" handlers). + + Returns: + ClientTasksCapability if any handlers are provided, None otherwise + """ + has_list = self.list_tasks is not default_list_tasks_handler + has_cancel = self.cancel_task is not default_cancel_task_handler + has_sampling = self.augmented_sampling is not default_task_augmented_sampling + has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation + + # If no handlers are provided, return None + if not any([has_list, has_cancel, has_sampling, has_elicitation]): + return None + + # Build requests capability if any request handlers are provided + requests_capability: types.ClientTasksRequestsCapability | None = None + if has_sampling or has_elicitation: + requests_capability = types.ClientTasksRequestsCapability( + sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability()) + if has_sampling + else None, + elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability()) + if has_elicitation + else None, + ) + + return types.ClientTasksCapability( + list=types.TasksListCapability() if has_list else None, + cancel=types.TasksCancelCapability() if has_cancel else None, + requests=requests_capability, + ) + + @staticmethod + def handles_request(request: types.ServerRequest) -> bool: + """Check if this handler handles the given request type.""" + return isinstance( + request.root, + types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest, + ) + + async def handle_request( + self, + ctx: RequestContext["ClientSession", Any], + responder: RequestResponder[types.ServerRequest, types.ClientResult], + ) -> None: + """Handle a task-related request from the server. + + Call handles_request() first to check if this handler can handle the request. + """ + client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( + types.ClientResult | types.ErrorData + ) + + match responder.request.root: + case types.GetTaskRequest(params=params): + response = await self.get_task(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case types.GetTaskPayloadRequest(params=params): + response = await self.get_task_result(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case types.ListTasksRequest(params=params): + response = await self.list_tasks(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case types.CancelTaskRequest(params=params): + response = await self.cancel_task(ctx, params) + client_response = client_response_type.validate_python(response) + await responder.respond(client_response) + + case _: # pragma: no cover + raise ValueError(f"Unhandled request type: {type(responder.request.root)}") + + +# Backwards compatibility aliases +default_task_augmented_sampling_callback = default_task_augmented_sampling +default_task_augmented_elicitation_callback = default_task_augmented_elicitation diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py new file mode 100644 index 0000000000..ce9c387462 --- /dev/null +++ b/src/mcp/client/experimental/tasks.py @@ -0,0 +1,224 @@ +""" +Experimental client-side task support. + +This module provides client methods for interacting with MCP tasks. + +WARNING: These APIs are experimental and may change without notice. + +Example: + # Call a tool as a task + result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"}) + task_id = result.task.taskId + + # Get task status + status = await session.experimental.get_task(task_id) + + # Get task result when complete + if status.status == "completed": + result = await session.experimental.get_task_result(task_id, CallToolResult) + + # List all tasks + tasks = await session.experimental.list_tasks() + + # Cancel a task + await session.experimental.cancel_task(task_id) +""" + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, TypeVar + +import mcp.types as types +from mcp.shared.experimental.tasks.polling import poll_until_terminal + +if TYPE_CHECKING: + from mcp.client.session import ClientSession + +ResultT = TypeVar("ResultT", bound=types.Result) + + +class ExperimentalClientFeatures: + """ + Experimental client features for tasks and other experimental APIs. + + WARNING: These APIs are experimental and may change without notice. + + Access via session.experimental: + status = await session.experimental.get_task(task_id) + """ + + def __init__(self, session: "ClientSession") -> None: + self._session = session + + async def call_tool_as_task( + self, + name: str, + arguments: dict[str, Any] | None = None, + *, + ttl: int = 60000, + meta: dict[str, Any] | None = None, + ) -> types.CreateTaskResult: + """Call a tool as a task, returning a CreateTaskResult for polling. + + This is a convenience method for calling tools that support task execution. + The server will return a task reference instead of the immediate result, + which can then be polled via `get_task()` and retrieved via `get_task_result()`. + + Args: + name: The tool name + arguments: Tool arguments + ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute) + meta: Optional metadata to include in the request + + Returns: + CreateTaskResult containing the task reference + + Example: + # Create task + result = await session.experimental.call_tool_as_task( + "long_running_tool", {"input": "data"} + ) + task_id = result.task.taskId + + # Poll for completion + while True: + status = await session.experimental.get_task(task_id) + if status.status == "completed": + break + await asyncio.sleep(0.5) + + # Get result + final = await session.experimental.get_task_result(task_id, CallToolResult) + """ + _meta: types.RequestParams.Meta | None = None + if meta is not None: + _meta = types.RequestParams.Meta(**meta) + + return await self._session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + task=types.TaskMetadata(ttl=ttl), + _meta=_meta, + ), + ) + ), + types.CreateTaskResult, + ) + + async def get_task(self, task_id: str) -> types.GetTaskResult: + """ + Get the current status of a task. + + Args: + task_id: The task identifier + + Returns: + GetTaskResult containing the task status and metadata + """ + return await self._session.send_request( + types.ClientRequest( + types.GetTaskRequest( + params=types.GetTaskRequestParams(taskId=task_id), + ) + ), + types.GetTaskResult, + ) + + async def get_task_result( + self, + task_id: str, + result_type: type[ResultT], + ) -> ResultT: + """ + Get the result of a completed task. + + The result type depends on the original request type: + - tools/call tasks return CallToolResult + - Other request types return their corresponding result type + + Args: + task_id: The task identifier + result_type: The expected result type (e.g., CallToolResult) + + Returns: + The task result, validated against result_type + """ + return await self._session.send_request( + types.ClientRequest( + types.GetTaskPayloadRequest( + params=types.GetTaskPayloadRequestParams(taskId=task_id), + ) + ), + result_type, + ) + + async def list_tasks( + self, + cursor: str | None = None, + ) -> types.ListTasksResult: + """ + List all tasks. + + Args: + cursor: Optional pagination cursor + + Returns: + ListTasksResult containing tasks and optional next cursor + """ + params = types.PaginatedRequestParams(cursor=cursor) if cursor else None + return await self._session.send_request( + types.ClientRequest( + types.ListTasksRequest(params=params), + ), + types.ListTasksResult, + ) + + async def cancel_task(self, task_id: str) -> types.CancelTaskResult: + """ + Cancel a running task. + + Args: + task_id: The task identifier + + Returns: + CancelTaskResult with the updated task state + """ + return await self._session.send_request( + types.ClientRequest( + types.CancelTaskRequest( + params=types.CancelTaskRequestParams(taskId=task_id), + ) + ), + types.CancelTaskResult, + ) + + async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: + """ + Poll a task until it reaches a terminal status. + + Yields GetTaskResult for each poll, allowing the caller to react to + status changes (e.g., handle input_required). Exits when task reaches + a terminal status (completed, failed, cancelled). + + Respects the pollInterval hint from the server. + + Args: + task_id: The task identifier + + Yields: + GetTaskResult for each poll + + Example: + async for status in session.experimental.poll_task(task_id): + print(f"Status: {status.status}") + if status.status == "input_required": + # Handle elicitation request via tasks/result + pass + + # Task is now terminal, get the result + result = await session.experimental.get_task_result(task_id, CallToolResult) + """ + async for status in poll_until_terminal(self.get_task, task_id): + yield status diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index bcf80d62a4..a7d03f87cd 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,13 +1,14 @@ import logging -from datetime import timedelta -from typing import Any, Protocol +from typing import Any, Protocol, overload import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from jsonschema import SchemaError, ValidationError, validate from pydantic import AnyUrl, TypeAdapter +from typing_extensions import deprecated import mcp.types as types +from mcp.client.experimental import ExperimentalClientFeatures +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -23,7 +24,7 @@ async def __call__( self, context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, - ) -> types.CreateMessageResult | types.ErrorData: ... + ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch class ElicitationFnT(Protocol): @@ -31,27 +32,27 @@ async def __call__( self, context: RequestContext["ClientSession", Any], params: types.ElicitRequestParams, - ) -> types.ElicitResult | types.ErrorData: ... + ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( self, context: RequestContext["ClientSession", Any] - ) -> types.ListRootsResult | types.ErrorData: ... + ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch class LoggingFnT(Protocol): async def __call__( self, params: types.LoggingMessageNotificationParams, - ) -> None: ... + ) -> None: ... # pragma: no branch class MessageHandlerFnT(Protocol): async def __call__( self, message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: ... + ) -> None: ... # pragma: no branch async def _default_message_handler( @@ -63,7 +64,7 @@ async def _default_message_handler( async def _default_sampling_callback( context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, -) -> types.CreateMessageResult | types.ErrorData: +) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, message="Sampling not supported", @@ -74,7 +75,7 @@ async def _default_elicitation_callback( context: RequestContext["ClientSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( + return types.ErrorData( # pragma: no cover code=types.INVALID_REQUEST, message="Elicitation not supported", ) @@ -111,13 +112,16 @@ def __init__( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], write_stream: MemoryObjectSendStream[SessionMessage], - read_timeout_seconds: timedelta | None = None, + read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + *, + sampling_capabilities: types.SamplingCapability | None = None, + experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: super().__init__( read_stream, @@ -128,16 +132,31 @@ def __init__( ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback + self._sampling_capabilities = sampling_capabilities self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} + self._server_capabilities: types.ServerCapabilities | None = None + self._experimental_features: ExperimentalClientFeatures | None = None + + # Experimental: Task handlers (use defaults if not provided) + self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() async def initialize(self) -> types.InitializeResult: - sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None + sampling = ( + (self._sampling_capabilities or types.SamplingCapability()) + if self._sampling_callback is not _default_sampling_callback + else None + ) elicitation = ( - types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None + types.ElicitationCapability( + form=types.FormElicitationCapability(), + url=types.UrlElicitationCapability(), + ) + if self._elicitation_callback is not _default_elicitation_callback + else None ) roots = ( # TODO: Should this be based on whether we @@ -158,6 +177,7 @@ async def initialize(self) -> types.InitializeResult: elicitation=elicitation, experimental=None, roots=roots, + tasks=self._task_handlers.build_capability(), ), clientInfo=self._client_info, ), @@ -169,10 +189,33 @@ async def initialize(self) -> types.InitializeResult: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") + self._server_capabilities = result.capabilities + await self.send_notification(types.ClientNotification(types.InitializedNotification())) return result + def get_server_capabilities(self) -> types.ServerCapabilities | None: + """Return the server capabilities received during initialization. + + Returns None if the session has not been initialized yet. + """ + return self._server_capabilities + + @property + def experimental(self) -> ExperimentalClientFeatures: + """Experimental APIs for tasks and other features. + + WARNING: These APIs are experimental and may change without notice. + + Example: + status = await session.experimental.get_task(task_id) + result = await session.experimental.get_task_result(task_id, CallToolResult) + """ + if self._experimental_features is None: + self._experimental_features = ExperimentalClientFeatures(self) + return self._experimental_features + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( @@ -203,7 +246,7 @@ async def send_progress_notification( async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: """Send a logging/setLevel request.""" - return await self.send_request( + return await self.send_request( # pragma: no cover types.ClientRequest( types.SetLevelRequest( params=types.SetLevelRequestParams(level=level), @@ -212,25 +255,79 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul types.EmptyResult, ) - async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult: - """Send a resources/list request.""" + @overload + @deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead") + async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ... + + @overload + async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ... + + @overload + async def list_resources(self) -> types.ListResourcesResult: ... + + async def list_resources( + self, + cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListResourcesResult: + """Send a resources/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ + if params is not None and cursor is not None: + raise ValueError("Cannot specify both cursor and params") + + if params is not None: + request_params = params + elif cursor is not None: + request_params = types.PaginatedRequestParams(cursor=cursor) + else: + request_params = None + return await self.send_request( - types.ClientRequest( - types.ListResourcesRequest( - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), + types.ClientRequest(types.ListResourcesRequest(params=request_params)), types.ListResourcesResult, ) - async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult: - """Send a resources/templates/list request.""" + @overload + @deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead") + async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ... + + @overload + async def list_resource_templates( + self, *, params: types.PaginatedRequestParams | None + ) -> types.ListResourceTemplatesResult: ... + + @overload + async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ... + + async def list_resource_templates( + self, + cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListResourceTemplatesResult: + """Send a resources/templates/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ + if params is not None and cursor is not None: + raise ValueError("Cannot specify both cursor and params") + + if params is not None: + request_params = params + elif cursor is not None: + request_params = types.PaginatedRequestParams(cursor=cursor) + else: + request_params = None + return await self.send_request( - types.ClientRequest( - types.ListResourceTemplatesRequest( - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), + types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)), types.ListResourceTemplatesResult, ) @@ -247,7 +344,7 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" - return await self.send_request( + return await self.send_request( # pragma: no cover types.ClientRequest( types.SubscribeRequest( params=types.SubscribeRequestParams(uri=uri), @@ -258,7 +355,7 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" - return await self.send_request( + return await self.send_request( # pragma: no cover types.ClientRequest( types.UnsubscribeRequest( params=types.UnsubscribeRequestParams(uri=uri), @@ -271,18 +368,21 @@ async def call_tool( self, name: str, arguments: dict[str, Any] | None = None, - read_timeout_seconds: timedelta | None = None, + read_timeout_seconds: float | None = None, progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" + _meta: types.RequestParams.Meta | None = None + if meta is not None: + _meta = types.RequestParams.Meta(**meta) + result = await self.send_request( types.ClientRequest( types.CallToolRequest( - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - ), + params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta), ) ), types.CallToolResult, @@ -308,23 +408,53 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - logger.warning(f"Tool {name} not listed by server, cannot validate any structured content") if output_schema is not None: + from jsonschema import SchemaError, ValidationError, validate + if result.structuredContent is None: - raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") + raise RuntimeError( + f"Tool {name} has an output schema but did not return structured content" + ) # pragma: no cover try: validate(result.structuredContent, output_schema) except ValidationError as e: - raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") - except SchemaError as e: - raise RuntimeError(f"Invalid schema for tool {name}: {e}") + raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") # pragma: no cover + except SchemaError as e: # pragma: no cover + raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover + + @overload + @deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead") + async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ... + + @overload + async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ... + + @overload + async def list_prompts(self) -> types.ListPromptsResult: ... + + async def list_prompts( + self, + cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListPromptsResult: + """Send a prompts/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ + if params is not None and cursor is not None: + raise ValueError("Cannot specify both cursor and params") + + if params is not None: + request_params = params + elif cursor is not None: + request_params = types.PaginatedRequestParams(cursor=cursor) + else: + request_params = None - async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult: - """Send a prompts/list request.""" return await self.send_request( - types.ClientRequest( - types.ListPromptsRequest( - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), + types.ClientRequest(types.ListPromptsRequest(params=request_params)), types.ListPromptsResult, ) @@ -363,14 +493,40 @@ async def complete( types.CompleteResult, ) - async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: - """Send a tools/list request.""" + @overload + @deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead") + async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ... + + @overload + async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ... + + @overload + async def list_tools(self) -> types.ListToolsResult: ... + + async def list_tools( + self, + cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListToolsResult: + """Send a tools/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ + if params is not None and cursor is not None: + raise ValueError("Cannot specify both cursor and params") + + if params is not None: + request_params = params + elif cursor is not None: + request_params = types.PaginatedRequestParams(cursor=cursor) + else: + request_params = None + result = await self.send_request( - types.ClientRequest( - types.ListToolsRequest( - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), + types.ClientRequest(types.ListToolsRequest(params=request_params)), types.ListToolsResult, ) @@ -381,7 +537,7 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: return result - async def send_roots_list_changed(self) -> None: + async def send_roots_list_changed(self) -> None: # pragma: no cover """Send a roots/list_changed notification.""" await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) @@ -393,16 +549,31 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques lifespan_context=None, ) + # Delegate to experimental task handler if applicable + if self._task_handlers.handles_request(responder.request): + with responder: + await self._task_handlers.handle_request(ctx, responder) + return None + + # Core request handling match responder.request.root: case types.CreateMessageRequest(params=params): with responder: - response = await self._sampling_callback(ctx, params) + # Check if this is a task-augmented request + if params.task is not None: + response = await self._task_handlers.augmented_sampling(ctx, params, params.task) + else: + response = await self._sampling_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) case types.ElicitRequest(params=params): with responder: - response = await self._elicitation_callback(ctx, params) + # Check if this is a task-augmented request + if params.task is not None: + response = await self._task_handlers.augmented_elicitation(ctx, params, params.task) + else: + response = await self._elicitation_callback(ctx, params) client_response = ClientResponse.validate_python(response) await responder.respond(client_response) @@ -412,10 +583,15 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): + case types.PingRequest(): # pragma: no cover with responder: return await responder.respond(types.ClientResult(root=types.EmptyResult())) + case _: # pragma: no cover + pass # Task requests handled above by _task_handlers + + return None + async def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -429,5 +605,10 @@ async def _received_notification(self, notification: types.ServerNotification) - match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) + case types.ElicitCompleteNotification(params=params): + # Handle elicitation completion notification + # Clients MAY use this to retry requests or update UI + # The notification contains the elicitationId of the completed elicitation + pass case _: pass diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 700b5417fb..db0146068a 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,20 +11,24 @@ import contextlib import logging from collections.abc import Callable -from datetime import timedelta +from dataclasses import dataclass from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, overload import anyio +import httpx from pydantic import BaseModel -from typing_extensions import Self +from typing_extensions import Self, deprecated import mcp from mcp import types +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT class SseServerParameters(BaseModel): @@ -36,15 +40,15 @@ class SseServerParameters(BaseModel): # Optional headers to include in requests. headers: dict[str, Any] | None = None - # HTTP timeout for regular operations. - timeout: float = 5 + # HTTP timeout for regular operations (in seconds). + timeout: float = 5.0 - # Timeout for SSE read operations. - sse_read_timeout: float = 60 * 5 + # Timeout for SSE read operations (in seconds). + sse_read_timeout: float = 300.0 class StreamableHttpParameters(BaseModel): - """Parameters for intializing a streamablehttp_client.""" + """Parameters for intializing a streamable_http_client.""" # The endpoint URL. url: str @@ -52,11 +56,11 @@ class StreamableHttpParameters(BaseModel): # Optional headers to include in requests. headers: dict[str, Any] | None = None - # HTTP timeout for regular operations. - timeout: timedelta = timedelta(seconds=30) + # HTTP timeout for regular operations (in seconds). + timeout: float = 30.0 - # Timeout for SSE read operations. - sse_read_timeout: timedelta = timedelta(seconds=60 * 5) + # Timeout for SSE read operations (in seconds). + sse_read_timeout: float = 300.0 # Close the client session when the transport closes. terminate_on_close: bool = True @@ -65,6 +69,21 @@ class StreamableHttpParameters(BaseModel): ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters +# Use dataclass instead of pydantic BaseModel +# because pydantic BaseModel cannot handle Protocol fields. +@dataclass +class ClientSessionParameters: + """Parameters for establishing a client session to an MCP server.""" + + read_timeout_seconds: float | None = None + sampling_callback: SamplingFnT | None = None + elicitation_callback: ElicitationFnT | None = None + list_roots_callback: ListRootsFnT | None = None + logging_callback: LoggingFnT | None = None + message_handler: MessageHandlerFnT | None = None + client_info: types.Implementation | None = None + + class ClientSessionGroup: """Client for managing connections to multiple MCP servers. @@ -77,7 +96,7 @@ class ClientSessionGroup: Example Usage: name_fn = lambda name, server_info: f"{(server_info.name)}_{name}" async with ClientSessionGroup(component_name_hook=name_fn) as group: - for server_params in server_params: + for server_param in server_params: await group.connect_to_server(server_param) ... @@ -129,7 +148,7 @@ def __init__( self._session_exit_stacks = {} self._component_name_hook = component_name_hook - async def __aenter__(self) -> Self: + async def __aenter__(self) -> Self: # pragma: no cover # Enter the exit stack only if we created it ourselves if self._owns_exit_stack: await self._exit_stack.__aenter__() @@ -140,7 +159,7 @@ async def __aexit__( _exc_type: type[BaseException] | None, _exc_val: BaseException | None, _exc_tb: TracebackType | None, - ) -> bool | None: + ) -> bool | None: # pragma: no cover """Closes session exit stacks and main exit stack upon completion.""" # Only close the main exit stack if we created it @@ -155,7 +174,7 @@ async def __aexit__( @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" - return list(self._sessions.keys()) + return list(self._sessions.keys()) # pragma: no cover @property def prompts(self) -> dict[str, types.Prompt]: @@ -172,11 +191,49 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools - async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any], + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + @overload + @deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.") + async def call_tool( + self, + name: str, + *, + args: dict[str, Any], + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + args: dict[str, Any] | None = None, + ) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] session_tool_name = self.tools[name].name - return await session.call_tool(session_tool_name, args) + return await session.call_tool( + session_tool_name, + arguments if args is None else args, + read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + meta=meta, + ) async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" @@ -192,7 +249,7 @@ async def disconnect_from_server(self, session: mcp.ClientSession) -> None: ) ) - if session_known_for_components: + if session_known_for_components: # pragma: no cover component_names = self._sessions.pop(session) # Pop from _sessions tracking # Remove prompts associated with the session. @@ -212,8 +269,8 @@ async def disconnect_from_server(self, session: mcp.ClientSession) -> None: # Clean up the session's resources via its dedicated exit stack if session_known_for_stack: - session_stack_to_close = self._session_exit_stacks.pop(session) - await session_stack_to_close.aclose() + session_stack_to_close = self._session_exit_stacks.pop(session) # pragma: no cover + await session_stack_to_close.aclose() # pragma: no cover async def connect_with_session( self, server_info: types.Implementation, session: mcp.ClientSession @@ -225,13 +282,16 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, + session_params: ClientSessionParameters | None = None, ) -> mcp.ClientSession: """Connects to a single MCP server.""" - server_info, session = await self._establish_session(server_params) + server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters()) return await self.connect_with_session(server_info, session) async def _establish_session( - self, server_params: ServerParameters + self, + server_params: ServerParameters, + session_params: ClientSessionParameters, ) -> tuple[types.Implementation, mcp.ClientSession]: """Establish a client session to an MCP server.""" @@ -250,16 +310,36 @@ async def _establish_session( ) read, write = await session_stack.enter_async_context(client) else: - client = streamablehttp_client( - url=server_params.url, + httpx_client = create_mcp_http_client( headers=server_params.headers, - timeout=server_params.timeout, - sse_read_timeout=server_params.sse_read_timeout, + timeout=httpx.Timeout( + server_params.timeout, + read=server_params.sse_read_timeout, + ), + ) + await session_stack.enter_async_context(httpx_client) + + client = streamable_http_client( + url=server_params.url, + http_client=httpx_client, terminate_on_close=server_params.terminate_on_close, ) read, write, _ = await session_stack.enter_async_context(client) - session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) + session = await session_stack.enter_async_context( + mcp.ClientSession( + read, + write, + read_timeout_seconds=session_params.read_timeout_seconds, + sampling_callback=session_params.sampling_callback, + elicitation_callback=session_params.elicitation_callback, + list_roots_callback=session_params.list_roots_callback, + logging_callback=session_params.logging_callback, + message_handler=session_params.message_handler, + client_info=session_params.client_info, + ) + ) + result = await session.initialize() # Session successfully initialized. @@ -270,7 +350,7 @@ async def _establish_session( await self._exit_stack.enter_async_context(session_stack) return result.serverInfo, session - except Exception: + except Exception: # pragma: no cover # If anything during this setup fails, ensure the session-specific # stack is closed. await session_stack.aclose() @@ -298,7 +378,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session name = self._component_name(prompt.name, server_info) prompts_temp[name] = prompt component_names.prompts.add(name) - except McpError as err: + except McpError as err: # pragma: no cover logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. @@ -308,7 +388,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session name = self._component_name(resource.name, server_info) resources_temp[name] = resource component_names.resources.add(name) - except McpError as err: + except McpError as err: # pragma: no cover logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. @@ -319,18 +399,18 @@ async def _aggregate_components(self, server_info: types.Implementation, session tools_temp[name] = tool tool_to_session_temp[name] = session component_names.tools.add(name) - except McpError as err: + except McpError as err: # pragma: no cover logging.warning(f"Could not fetch tools: {err}") # Clean up exit stack for session if we couldn't retrieve anything # from the server. if not any((prompts_temp, resources_temp, tools_temp)): - del self._session_exit_stacks[session] + del self._session_exit_stacks[session] # pragma: no cover # Check for duplicates. matching_prompts = prompts_temp.keys() & self._prompts.keys() if matching_prompts: - raise McpError( + raise McpError( # pragma: no cover types.ErrorData( code=types.INVALID_PARAMS, message=f"{matching_prompts} already exist in group prompts.", @@ -338,7 +418,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session ) matching_resources = resources_temp.keys() & self._resources.keys() if matching_resources: - raise McpError( + raise McpError( # pragma: no cover types.ErrorData( code=types.INVALID_PARAMS, message=f"{matching_resources} already exist in group resources.", diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 791c602cdd..4b0bbbc1e7 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,7 +1,8 @@ import logging +from collections.abc import Callable from contextlib import asynccontextmanager from typing import Any -from urllib.parse import urljoin, urlparse +from urllib.parse import parse_qs, urljoin, urlparse import anyio import httpx @@ -21,14 +22,20 @@ def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) +def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None: + query_params = parse_qs(urlparse(endpoint_url).query) + return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0] + + @asynccontextmanager async def sse_client( url: str, headers: dict[str, Any] | None = None, - timeout: float = 5, - sse_read_timeout: float = 60 * 5, + timeout: float = 5.0, + sse_read_timeout: float = 300.0, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + on_session_created: Callable[[str], None] | None = None, ): """ Client transport for SSE. @@ -39,9 +46,10 @@ async def sse_client( Args: url: The SSE endpoint URL. headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations. - sse_read_timeout: Timeout for SSE read operations. + timeout: HTTP timeout for regular operations (in seconds). + sse_read_timeout: Timeout for SSE read operations (in seconds). auth: Optional HTTPX authentication handler. + on_session_created: Optional callback invoked with the session ID when received. """ read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] @@ -70,7 +78,7 @@ async def sse_reader( task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, ): try: - async for sse in event_source.aiter_sse(): + async for sse in event_source.aiter_sse(): # pragma: no branch logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": @@ -79,39 +87,47 @@ async def sse_reader( url_parsed = urlparse(url) endpoint_parsed = urlparse(endpoint_url) - if ( + if ( # pragma: no cover url_parsed.netloc != endpoint_parsed.netloc or url_parsed.scheme != endpoint_parsed.scheme ): - error_msg = ( + error_msg = ( # pragma: no cover f"Endpoint origin does not match connection origin: {endpoint_url}" ) - logger.error(error_msg) - raise ValueError(error_msg) + logger.error(error_msg) # pragma: no cover + raise ValueError(error_msg) # pragma: no cover + + if on_session_created: + session_id = _extract_session_id_from_endpoint(endpoint_url) + if session_id: + on_session_created(session_id) task_status.started(endpoint_url) case "message": + # Skip empty data (keep-alive pings) + if not sse.data: + continue try: message = types.JSONRPCMessage.model_validate_json( # noqa: E501 sse.data ) logger.debug(f"Received server message: {message}") - except Exception as exc: - logger.exception("Error parsing server message") - await read_stream_writer.send(exc) - continue + except Exception as exc: # pragma: no cover + logger.exception("Error parsing server message") # pragma: no cover + await read_stream_writer.send(exc) # pragma: no cover + continue # pragma: no cover session_message = SessionMessage(message) await read_stream_writer.send(session_message) - case _: - logger.warning(f"Unknown SSE event: {sse.event}") - except SSEError as sse_exc: - logger.exception("Encountered SSE exception") - raise sse_exc - except Exception as exc: - logger.exception("Error in sse_reader") - await read_stream_writer.send(exc) + case _: # pragma: no cover + logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover + except SSEError as sse_exc: # pragma: no cover + logger.exception("Encountered SSE exception") # pragma: no cover + raise sse_exc # pragma: no cover + except Exception as exc: # pragma: no cover + logger.exception("Error in sse_reader") # pragma: no cover + await read_stream_writer.send(exc) # pragma: no cover finally: await read_stream_writer.aclose() @@ -130,8 +146,8 @@ async def post_writer(endpoint_url: str): ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception: - logger.exception("Error in post_writer") + except Exception: # pragma: no cover + logger.exception("Error in post_writer") # pragma: no cover finally: await write_stream.aclose() diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 6dc7c89afb..0d76bb958b 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -58,11 +58,11 @@ def get_default_environment() -> dict[str, str]: for key in DEFAULT_INHERITED_ENV_VARS: value = os.environ.get(key) if value is None: - continue + continue # pragma: no cover - if value.startswith("()"): + if value.startswith("()"): # pragma: no cover # Skip functions, which are a security risk - continue + continue # pragma: no cover env[key] = value @@ -153,14 +153,14 @@ async def stdout_reader(): for line in lines: try: message = types.JSONRPCMessage.model_validate_json(line) - except Exception as exc: + except Exception as exc: # pragma: no cover logger.exception("Failed to parse JSONRPC message from server") await read_stream_writer.send(exc) continue session_message = SessionMessage(message) await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() async def stdin_writer(): @@ -176,7 +176,7 @@ async def stdin_writer(): errors=server.encoding_error_handler, ) ) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() async with ( @@ -192,10 +192,10 @@ async def stdin_writer(): # 1. Close input stream to server # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time # 3. Send SIGKILL if still not exited - if process.stdin: + if process.stdin: # pragma: no branch try: await process.stdin.aclose() - except Exception: + except Exception: # pragma: no cover # stdin might already be closed, which is fine pass @@ -207,7 +207,7 @@ async def stdin_writer(): # Process didn't exit from stdin closure, use platform-specific termination # which handles SIGTERM -> SIGKILL escalation await _terminate_process_tree(process) - except ProcessLookupError: + except ProcessLookupError: # pragma: no cover # Process already exited, which is fine pass await read_stream.aclose() @@ -226,10 +226,10 @@ def _get_executable_command(command: str) -> str: Returns: str: Platform-appropriate command """ - if sys.platform == "win32": + if sys.platform == "win32": # pragma: no cover return get_windows_executable_command(command) else: - return command + return command # pragma: no cover async def _create_platform_compatible_process( @@ -245,7 +245,7 @@ async def _create_platform_compatible_process( Unix: Creates process in a new session/process group for killpg support Windows: Creates process in a Job Object for reliable child termination """ - if sys.platform == "win32": + if sys.platform == "win32": # pragma: no cover process = await create_windows_process(command, args, env, errlog, cwd) else: process = await anyio.open_process( @@ -254,7 +254,7 @@ async def _create_platform_compatible_process( stderr=errlog, cwd=cwd, start_new_session=True, - ) + ) # pragma: no cover return process @@ -270,9 +270,9 @@ async def _terminate_process_tree(process: Process | FallbackProcess, timeout_se process: The process to terminate timeout_seconds: Timeout in seconds before force killing (default: 2.0) """ - if sys.platform == "win32": + if sys.platform == "win32": # pragma: no cover await terminate_windows_process_tree(process, timeout_seconds) - else: + else: # pragma: no cover # FallbackProcess should only be used for Windows compatibility assert isinstance(process, Process) await terminate_posix_process_tree(process, timeout_seconds) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 57df647057..22645d3ba5 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -6,19 +6,26 @@ and session management. """ +import contextlib import logging from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta +from typing import Any, overload +from warnings import warn import anyio import httpx from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse +from typing_extensions import deprecated -from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client +from mcp.shared._httpx_utils import ( + McpHttpClientFactory, + create_mcp_http_client, +) from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, @@ -42,6 +49,10 @@ MCP_SESSION_ID = "mcp-session-id" MCP_PROTOCOL_VERSION = "mcp-protocol-version" LAST_EVENT_ID = "last-event-id" + +# Reconnection defaults +DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry +MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up CONTENT_TYPE = "content-type" ACCEPT = "accept" @@ -49,6 +60,9 @@ JSON = "application/json" SSE = "text/event-stream" +# Sentinel value for detecting unset optional parameters +_UNSET = object() + class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" @@ -63,52 +77,85 @@ class RequestContext: """Context for a request operation.""" client: httpx.AsyncClient - headers: dict[str, str] session_id: str | None session_message: SessionMessage metadata: ClientMessageMetadata | None read_stream_writer: StreamWriter - sse_read_timeout: float + headers: dict[str, str] | None = None # Deprecated - no longer used + sse_read_timeout: float | None = None # Deprecated - no longer used class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" + @overload + def __init__(self, url: str) -> None: ... + + @overload + @deprecated( + "Parameters headers, timeout, sse_read_timeout, and auth are deprecated. " + "Configure these on the httpx.AsyncClient instead." + ) def __init__( self, url: str, headers: dict[str, str] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 60 * 5, + timeout: float = 30.0, + sse_read_timeout: float = 300.0, auth: httpx.Auth | None = None, + ) -> None: ... + + def __init__( + self, + url: str, + headers: Any = _UNSET, + timeout: Any = _UNSET, + sse_read_timeout: Any = _UNSET, + auth: Any = _UNSET, ) -> None: """Initialize the StreamableHTTP transport. Args: url: The endpoint URL. headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations. - sse_read_timeout: Timeout for SSE read operations. + timeout: HTTP timeout for regular operations (in seconds). + sse_read_timeout: Timeout for SSE read operations (in seconds). auth: Optional HTTPX authentication handler. """ + # Check for deprecated parameters and issue runtime warning + deprecated_params: list[str] = [] + if headers is not _UNSET: + deprecated_params.append("headers") + if timeout is not _UNSET: + deprecated_params.append("timeout") + if sse_read_timeout is not _UNSET: + deprecated_params.append("sse_read_timeout") + if auth is not _UNSET: + deprecated_params.append("auth") + + if deprecated_params: + warn( + f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. " + "Configure these on the httpx.AsyncClient instead.", + DeprecationWarning, + stacklevel=2, + ) + self.url = url - self.headers = headers or {} - self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - self.sse_read_timeout = ( - sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout - ) - self.auth = auth self.session_id = None self.protocol_version = None - self.request_headers = { - ACCEPT: f"{JSON}, {SSE}", - CONTENT_TYPE: JSON, - **self.headers, - } - - def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: - """Update headers with session ID and protocol version if available.""" - headers = base_headers.copy() + + def _prepare_headers(self) -> dict[str, str]: + """Build MCP-specific request headers. + + These headers will be merged with the httpx.AsyncClient's default headers, + with these MCP-specific headers taking precedence. + """ + headers: dict[str, str] = {} + # Add MCP protocol headers + headers[ACCEPT] = f"{JSON}, {SSE}" + headers[CONTENT_TYPE] = JSON + # Add session headers if available if self.session_id: headers[MCP_SESSION_ID] = self.session_id if self.protocol_version: @@ -138,14 +185,16 @@ def _maybe_extract_protocol_version_from_message( message: JSONRPCMessage, ) -> None: """Extract protocol version from initialization response message.""" - if isinstance(message.root, JSONRPCResponse) and message.root.result: + if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch try: # Parse the result as InitializeResult for type safety init_result = InitializeResult.model_validate(message.root.result) self.protocol_version = str(init_result.protocolVersion) logger.info(f"Negotiated protocol version: {self.protocol_version}") - except Exception as exc: - logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}") + except Exception as exc: # pragma: no cover + logger.warning( + f"Failed to parse initialization response as InitializeResult: {exc}" + ) # pragma: no cover logger.warning(f"Raw result: {message.root.result}") async def _handle_sse_event( @@ -158,6 +207,12 @@ async def _handle_sse_event( ) -> bool: """Handle an SSE event, returning True if the response is complete.""" if sse.event == "message": + # Handle priming events (empty data with ID) for resumability + if not sse.data: + # Call resumption callback for priming events that have an ID + if sse.id and resumption_callback: + await resumption_callback(sse.id) + return False try: message = JSONRPCMessage.model_validate_json(sse.data) logger.debug(f"SSE message: {message}") @@ -181,11 +236,11 @@ async def _handle_sse_event( # Otherwise, return False to continue listening return isinstance(message.root, JSONRPCResponse | JSONRPCError) - except Exception as exc: + except Exception as exc: # pragma: no cover logger.exception("Error parsing SSE message") await read_stream_writer.send(exc) return False - else: + else: # pragma: no cover logger.warning(f"Unknown SSE event: {sse.event}") return False @@ -194,40 +249,66 @@ async def handle_get_stream( client: httpx.AsyncClient, read_stream_writer: StreamWriter, ) -> None: - """Handle GET stream for server-initiated messages.""" - try: - if not self.session_id: - return + """Handle GET stream for server-initiated messages with auto-reconnect.""" + last_event_id: str | None = None + retry_interval_ms: int | None = None + attempt: int = 0 - headers = self._prepare_request_headers(self.request_headers) + while attempt < MAX_RECONNECTION_ATTEMPTS: # pragma: no branch + try: + if not self.session_id: + return - async with aconnect_sse( - client, - "GET", - self.url, - headers=headers, - timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), - ) as event_source: - event_source.response.raise_for_status() - logger.debug("GET SSE connection established") + headers = self._prepare_headers() + if last_event_id: + headers[LAST_EVENT_ID] = last_event_id # pragma: no cover - async for sse in event_source.aiter_sse(): - await self._handle_sse_event(sse, read_stream_writer) + async with aconnect_sse( + client, + "GET", + self.url, + headers=headers, + ) as event_source: + event_source.response.raise_for_status() + logger.debug("GET SSE connection established") + + async for sse in event_source.aiter_sse(): + # Track last event ID for reconnection + if sse.id: + last_event_id = sse.id # pragma: no cover + # Track retry interval from server + if sse.retry is not None: + retry_interval_ms = sse.retry # pragma: no cover + + await self._handle_sse_event(sse, read_stream_writer) + + # Stream ended normally (server closed) - reset attempt counter + attempt = 0 + + except Exception as exc: # pragma: no cover + logger.debug(f"GET stream error: {exc}") + attempt += 1 + + if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover + logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") + return - except Exception as exc: - logger.debug(f"GET stream error (non-fatal): {exc}") + # Wait before reconnecting + delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS + logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...") + await anyio.sleep(delay_ms / 1000.0) async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" - headers = self._prepare_request_headers(ctx.headers) + headers = self._prepare_headers() if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: - raise ResumptionError("Resumption request requires a resumption token") + raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover # Extract original request ID to map responses original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): + if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch original_request_id = ctx.session_message.message.root.id async with aconnect_sse( @@ -235,12 +316,11 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: "GET", self.url, headers=headers, - timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), ) as event_source: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") - async for sse in event_source.aiter_sse(): + async for sse in event_source.aiter_sse(): # pragma: no branch is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -253,7 +333,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._prepare_request_headers(ctx.headers) + headers = self._prepare_headers() message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -267,13 +347,13 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Received 202 Accepted") return - if response.status_code == 404: + if response.status_code == 404: # pragma: no branch if isinstance(message.root, JSONRPCRequest): - await self._send_session_terminated_error( - ctx.read_stream_writer, - message.root.id, - ) - return + await self._send_session_terminated_error( # pragma: no cover + ctx.read_stream_writer, # pragma: no cover + message.root.id, # pragma: no cover + ) # pragma: no cover + return # pragma: no cover response.raise_for_status() if is_initialization: @@ -288,10 +368,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: elif content_type.startswith(SSE): await self._handle_sse_response(response, ctx, is_initialization) else: - await self._handle_unexpected_content_type( - content_type, - ctx.read_stream_writer, - ) + await self._handle_unexpected_content_type( # pragma: no cover + content_type, # pragma: no cover + ctx.read_stream_writer, # pragma: no cover + ) # pragma: no cover async def _handle_json_response( self, @@ -310,7 +390,7 @@ async def _handle_json_response( session_message = SessionMessage(message) await read_stream_writer.send(session_message) - except Exception as exc: + except Exception as exc: # pragma: no cover logger.exception("Error parsing JSON response") await read_stream_writer.send(exc) @@ -321,9 +401,20 @@ async def _handle_sse_response( is_initialization: bool = False, ) -> None: """Handle SSE response from the server.""" + last_event_id: str | None = None + retry_interval_ms: int | None = None + try: event_source = EventSource(response) - async for sse in event_source.aiter_sse(): + async for sse in event_source.aiter_sse(): # pragma: no branch + # Track last event ID for potential reconnection + if sse.id: + last_event_id = sse.id + + # Track retry interval from server + if sse.retry is not None: + retry_interval_ms = sse.retry + is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, @@ -334,20 +425,87 @@ async def _handle_sse_response( # break the loop if is_complete: await response.aclose() - break - except Exception as e: - logger.exception("Error reading SSE stream:") - await ctx.read_stream_writer.send(e) + return # Normal completion, no reconnect needed + except Exception as e: # pragma: no cover + logger.debug(f"SSE stream ended: {e}") + + # Stream ended without response - reconnect if we received an event with ID + if last_event_id is not None: # pragma: no branch + logger.info("SSE stream disconnected, reconnecting...") + await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + + async def _handle_reconnection( + self, + ctx: RequestContext, + last_event_id: str, + retry_interval_ms: int | None = None, + attempt: int = 0, + ) -> None: + """Reconnect with Last-Event-ID to resume stream after server disconnect.""" + # Bail if max retries exceeded + if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover + logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") + return + + # Always wait - use server value or default + delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS + await anyio.sleep(delay_ms / 1000.0) + + headers = self._prepare_headers() + headers[LAST_EVENT_ID] = last_event_id + + # Extract original request ID to map responses + original_request_id = None + if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch + original_request_id = ctx.session_message.message.root.id + + try: + async with aconnect_sse( + ctx.client, + "GET", + self.url, + headers=headers, + ) as event_source: + event_source.response.raise_for_status() + logger.info("Reconnected to SSE stream") + + # Track for potential further reconnection + reconnect_last_event_id: str = last_event_id + reconnect_retry_ms = retry_interval_ms + + async for sse in event_source.aiter_sse(): + if sse.id: # pragma: no branch + reconnect_last_event_id = sse.id + if sse.retry is not None: + reconnect_retry_ms = sse.retry + + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + await event_source.response.aclose() + return + + # Stream ended again without response - reconnect again (reset attempt counter) + logger.info("SSE stream disconnected, reconnecting...") + await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) + except Exception as e: # pragma: no cover + logger.debug(f"Reconnection failed: {e}") + # Try to reconnect again if we still have an event ID + await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) async def _handle_unexpected_content_type( self, content_type: str, read_stream_writer: StreamWriter, - ) -> None: + ) -> None: # pragma: no cover """Handle unexpected content type in response.""" - error_msg = f"Unexpected content type: {content_type}" - logger.error(error_msg) - await read_stream_writer.send(ValueError(error_msg)) + error_msg = f"Unexpected content type: {content_type}" # pragma: no cover + logger.error(error_msg) # pragma: no cover + await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover async def _send_session_terminated_error( self, @@ -394,12 +552,10 @@ async def post_writer( ctx = RequestContext( client=client, - headers=self.request_headers, session_id=self.session_id, session_message=session_message, metadata=metadata, read_stream_writer=read_stream_writer, - sse_read_timeout=self.sse_read_timeout, ) async def handle_request_async(): @@ -415,18 +571,18 @@ async def handle_request_async(): await handle_request_async() except Exception: - logger.exception("Error in post_writer") + logger.exception("Error in post_writer") # pragma: no cover finally: await read_stream_writer.aclose() await write_stream.aclose() - async def terminate_session(self, client: httpx.AsyncClient) -> None: + async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma: no cover """Terminate the session by sending a DELETE request.""" if not self.session_id: return try: - headers = self._prepare_request_headers(self.request_headers) + headers = self._prepare_headers() response = await client.delete(self.url, headers=headers) if response.status_code == 405: @@ -442,14 +598,11 @@ def get_session_id(self) -> str | None: @asynccontextmanager -async def streamablehttp_client( +async def streamable_http_client( url: str, - headers: dict[str, str] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 60 * 5, + *, + http_client: httpx.AsyncClient | None = None, terminate_on_close: bool = True, - httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, - auth: httpx.Auth | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -461,30 +614,45 @@ async def streamablehttp_client( """ Client transport for StreamableHTTP. - `sse_read_timeout` determines how long (in seconds) the client will wait for a new - event before disconnecting. All other HTTP operations are controlled by `timeout`. + Args: + url: The MCP server endpoint URL. + http_client: Optional pre-configured httpx.AsyncClient. If None, a default + client with recommended MCP timeouts will be created. To configure headers, + authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. + terminate_on_close: If True, send a DELETE request to terminate the session + when the context exits. Yields: Tuple containing: - read_stream: Stream for reading messages from the server - write_stream: Stream for sending messages to the server - get_session_id_callback: Function to retrieve the current session ID - """ - transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) + Example: + See examples/snippets/clients/ for usage patterns. + """ read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + # Determine if we need to create and manage the client + client_provided = http_client is not None + client = http_client + + if client is None: + # Create default client with recommended MCP timeouts + client = create_mcp_http_client() + + transport = StreamableHTTPTransport(url) + async with anyio.create_task_group() as tg: try: logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") - async with httpx_client_factory( - headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), - auth=transport.auth, - ) as client: - # Define callbacks that need access to tg + async with contextlib.AsyncExitStack() as stack: + # Only manage client lifecycle if we created it + if not client_provided: + await stack.enter_async_context(client) + def start_get_stream() -> None: tg.start_soon(transport.handle_get_stream, client, read_stream_writer) @@ -511,3 +679,44 @@ def start_get_stream() -> None: finally: await read_stream_writer.aclose() await write_stream.aclose() + + +@asynccontextmanager +@deprecated("Use `streamable_http_client` instead.") +async def streamablehttp_client( + url: str, + headers: dict[str, str] | None = None, + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, + terminate_on_close: bool = True, + httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, + auth: httpx.Auth | None = None, +) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback, + ], + None, +]: + # Convert timeout parameters + timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + sse_read_timeout_seconds = ( + sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout + ) + + # Create httpx client using the factory with old-style parameters + client = httpx_client_factory( + headers=headers, + timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds), + auth=auth, + ) + + # Manage client lifecycle since we created it + async with client: + async with streamable_http_client( + url, + http_client=client, + terminate_on_close=terminate_on_close, + ) as streams: + yield streams diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 0a371610bd..e8c8d9af87 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -59,7 +59,7 @@ async def ws_reader(): message = types.JSONRPCMessage.model_validate_json(raw_text) session_message = SessionMessage(message) await read_stream_writer.send(session_message) - except ValidationError as exc: + except ValidationError as exc: # pragma: no cover # If JSON parse or model validation fails, send the exception await read_stream_writer.send(exc) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 850f8373d6..3570d28c2a 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -50,7 +50,7 @@ class AuthorizationErrorResponse(BaseModel): def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None: - if params is None: + if params is None: # pragma: no cover return None value = params.get(key) if isinstance(value, str): @@ -116,7 +116,7 @@ async def error_response( pass # the error response MUST contain the state specified by the client, if any - if state is None: + if state is None: # pragma: no cover # make last-ditch effort to load state state = best_effort_extract_string("state", params) @@ -218,7 +218,7 @@ async def error_response( # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 return await error_response(error=e.error, error_description=e.error_description) - except Exception as validation_error: + except Exception as validation_error: # pragma: no cover # Catch-all for unexpected errors logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) return await error_response(error="server_error", error_description="An unexpected error occurred") diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 93720340a7..c65473d1fc 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -49,8 +49,13 @@ async def handle(self, request: Request) -> Response: ) client_id = str(uuid4()) + + # If auth method is None, default to client_secret_post + if client_metadata.token_endpoint_auth_method is None: + client_metadata.token_endpoint_auth_method = "client_secret_post" + client_secret = None - if client_metadata.token_endpoint_auth_method != "none": + if client_metadata.token_endpoint_auth_method != "none": # pragma: no branch # cryptographically secure random 32-byte hex string client_secret = secrets.token_hex(32) @@ -59,7 +64,7 @@ async def handle(self, request: Request) -> Response: elif client_metadata.scope is not None and self.options.valid_scopes is not None: requested_scopes = set(client_metadata.scope.split()) valid_scopes = set(self.options.valid_scopes) - if not requested_scopes.issubset(valid_scopes): + if not requested_scopes.issubset(valid_scopes): # pragma: no branch return PydanticJSONResponse( content=RegistrationErrorResponse( error="invalid_client_metadata", diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 478ad7a011..fa8cfc99d0 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -40,28 +40,25 @@ async def handle(self, request: Request) -> Response: Handler for the OAuth 2.0 Token Revocation endpoint. """ try: - form_data = await request.form() - revocation_request = RevocationRequest.model_validate(dict(form_data)) - except ValidationError as e: + client = await self.client_authenticator.authenticate_request(request) + except AuthenticationError as e: # pragma: no cover return PydanticJSONResponse( - status_code=400, + status_code=401, content=RevocationErrorResponse( - error="invalid_request", - error_description=stringify_pydantic_error(e), + error="unauthorized_client", + error_description=e.message, ), ) - # Authenticate client try: - client = await self.client_authenticator.authenticate( - revocation_request.client_id, revocation_request.client_secret - ) - except AuthenticationError as e: + form_data = await request.form() + revocation_request = RevocationRequest.model_validate(dict(form_data)) + except ValidationError as e: return PydanticJSONResponse( - status_code=401, + status_code=400, content=RevocationErrorResponse( - error="unauthorized_client", - error_description=e.message, + error="invalid_request", + error_description=stringify_pydantic_error(e), ), ) @@ -69,7 +66,7 @@ async def handle(self, request: Request) -> Response: self.provider.load_access_token, partial(self.provider.load_refresh_token, client), ] - if revocation_request.token_type_hint == "refresh_token": + if revocation_request.token_type_hint == "refresh_token": # pragma: no cover loaders = reversed(loaders) token: None | AccessToken | RefreshToken = None diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 4e15e6265c..7e8294ce6e 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -91,10 +91,26 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse): ) async def handle(self, request: Request): + try: + client_info = await self.client_authenticator.authenticate_request(request) + except AuthenticationError as e: + # Authentication failures should return 401 + return PydanticJSONResponse( + content=TokenErrorResponse( + error="invalid_client", + error_description=e.message, + ), + status_code=401, + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + }, + ) + try: form_data = await request.form() token_request = TokenRequest.model_validate(dict(form_data)).root - except ValidationError as validation_error: + except ValidationError as validation_error: # pragma: no cover return self.response( TokenErrorResponse( error="invalid_request", @@ -102,20 +118,7 @@ async def handle(self, request: Request): ) ) - try: - client_info = await self.client_authenticator.authenticate( - client_id=token_request.client_id, - client_secret=token_request.client_secret, - ) - except AuthenticationError as e: - return self.response( - TokenErrorResponse( - error="unauthorized_client", - error_description=e.message, - ) - ) - - if token_request.grant_type not in client_info.grant_types: + if token_request.grant_type not in client_info.grant_types: # pragma: no cover return self.response( TokenErrorResponse( error="unsupported_grant_type", @@ -151,7 +154,7 @@ async def handle(self, request: Request): # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 if auth_code.redirect_uri_provided_explicitly: authorize_request_redirect_uri = auth_code.redirect_uri - else: + else: # pragma: no cover authorize_request_redirect_uri = None # Convert both sides to strings for comparison to handle AnyUrl vs string issues @@ -192,7 +195,7 @@ async def handle(self, request: Request): ) ) - case RefreshTokenRequest(): + case RefreshTokenRequest(): # pragma: no cover refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) if refresh_token is None or refresh_token.client_id != token_request.client_id: # if token belongs to different client, pretend it doesn't exist diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6251e5ad5b..64c9b8841f 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -99,7 +99,7 @@ async def _send_auth_error(self, send: Send, status_code: int, error: str, descr """Send an authentication error response with WWW-Authenticate header.""" # Build WWW-Authenticate header value www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] - if self.resource_metadata_url: + if self.resource_metadata_url: # pragma: no cover www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') www_authenticate = f"Bearer {', '.join(www_auth_parts)}" diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index d5f473b484..6126c6e4f9 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -1,5 +1,11 @@ +import base64 +import binascii +import hmac import time from typing import Any +from urllib.parse import unquote + +from starlette.requests import Request from mcp.server.auth.provider import OAuthAuthorizationServerProvider from mcp.shared.auth import OAuthClientInformationFull @@ -7,7 +13,7 @@ class AuthenticationError(Exception): def __init__(self, message: str): - self.message = message + self.message = message # pragma: no cover class ClientAuthenticator: @@ -30,22 +36,80 @@ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """ self.provider = provider - async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull: - # Look up client information - client = await self.provider.get_client(client_id) + async def authenticate_request(self, request: Request) -> OAuthClientInformationFull: + """ + Authenticate a client from an HTTP request. + + Extracts client credentials from the appropriate location based on the + client's registered authentication method and validates them. + + Args: + request: The HTTP request containing client credentials + + Returns: + The authenticated client information + + Raises: + AuthenticationError: If authentication fails + """ + form_data = await request.form() + client_id = form_data.get("client_id") + if not client_id: + raise AuthenticationError("Missing client_id") + + client = await self.provider.get_client(str(client_id)) if not client: - raise AuthenticationError("Invalid client_id") + raise AuthenticationError("Invalid client_id") # pragma: no cover + + request_client_secret: str | None = None + auth_header = request.headers.get("Authorization", "") + + if client.token_endpoint_auth_method == "client_secret_basic": + if not auth_header.startswith("Basic "): + raise AuthenticationError("Missing or invalid Basic authentication in Authorization header") + + try: + encoded_credentials = auth_header[6:] # Remove "Basic " prefix + decoded = base64.b64decode(encoded_credentials).decode("utf-8") + if ":" not in decoded: + raise ValueError("Invalid Basic auth format") + basic_client_id, request_client_secret = decoded.split(":", 1) + + # URL-decode both parts per RFC 6749 Section 2.3.1 + basic_client_id = unquote(basic_client_id) + request_client_secret = unquote(request_client_secret) + + if basic_client_id != client_id: + raise AuthenticationError("Client ID mismatch in Basic auth") + except (ValueError, UnicodeDecodeError, binascii.Error): + raise AuthenticationError("Invalid Basic authentication header") + + elif client.token_endpoint_auth_method == "client_secret_post": + raw_form_data = form_data.get("client_secret") + # form_data.get() can return a UploadFile or None, so we need to check if it's a string + if isinstance(raw_form_data, str): + request_client_secret = str(raw_form_data) + + elif client.token_endpoint_auth_method == "none": + request_client_secret = None + else: + raise AuthenticationError( # pragma: no cover + f"Unsupported auth method: {client.token_endpoint_auth_method}" + ) # If client from the store expects a secret, validate that the request provides # that secret - if client.client_secret: - if not client_secret: - raise AuthenticationError("Client secret is required") + if client.client_secret: # pragma: no branch + if not request_client_secret: + raise AuthenticationError("Client secret is required") # pragma: no cover - if client.client_secret != client_secret: - raise AuthenticationError("Invalid client_secret") + # hmac.compare_digest requires that both arguments are either bytes or a `str` containing + # only ASCII characters. Since we do not control `request_client_secret`, we encode both + # arguments to bytes. + if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()): + raise AuthenticationError("Invalid client_secret") # pragma: no cover if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()): - raise AuthenticationError("Client secret has expired") + raise AuthenticationError("Client secret has expired") # pragma: no cover return client diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index a7b1086027..96296c148e 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -117,7 +117,6 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: Returns: The client information, or None if the client does not exist. """ - ... async def register_client(self, client_info: OAuthClientInformationFull) -> None: """ @@ -132,7 +131,6 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None Raises: RegistrationError: If the client metadata is invalid. """ - ... async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: """ @@ -221,8 +219,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t Returns: The RefreshToken object if found, or None if not found. """ - - ... + ... async def exchange_refresh_token( self, @@ -258,7 +255,6 @@ async def load_access_token(self, token: str) -> AccessTokenT | None: Returns: The AuthInfo, or None if the token is invalid. """ - ... async def revoke_token( self, @@ -276,7 +272,6 @@ async def revoke_token( Args: token: the token to revoke """ - ... def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 862b9a2d9a..71a9c8b165 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -38,13 +38,13 @@ def validate_issuer_url(/service/url: AnyHttpUrl): and url.host != "localhost" and (url.host is not None and not url.host.startswith("127.0.0.1")) ): - raise ValueError("Issuer URL must be HTTPS") + raise ValueError("Issuer URL must be HTTPS") # pragma: no cover # No fragments or query parameters allowed if url.fragment: - raise ValueError("Issuer URL must not have a fragment") + raise ValueError("Issuer URL must not have a fragment") # pragma: no cover if url.query: - raise ValueError("Issuer URL must not have a query string") + raise ValueError("Issuer URL must not have a query string") # pragma: no cover AUTHORIZATION_PATH = "/authorize" @@ -115,7 +115,7 @@ def create_auth_routes( ), ] - if client_registration_options.enabled: + if client_registration_options.enabled: # pragma: no branch registration_handler = RegistrationHandler( provider, options=client_registration_options, @@ -131,7 +131,7 @@ def create_auth_routes( ) ) - if revocation_options.enabled: + if revocation_options.enabled: # pragma: no branch revocation_handler = RevocationHandler(provider, client_authenticator) routes.append( Route( @@ -165,7 +165,7 @@ def build_metadata( response_types_supported=["code"], response_modes_supported=None, grant_types_supported=["authorization_code", "refresh_token"], - token_endpoint_auth_methods_supported=["client_secret_post"], + token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"], token_endpoint_auth_signing_alg_values_supported=None, service_documentation=service_documentation_url, ui_locales_supported=None, @@ -176,13 +176,13 @@ def build_metadata( ) # Add registration endpoint if supported - if client_registration_options.enabled: + if client_registration_options.enabled: # pragma: no branch metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH) # Add revocation endpoint if supported - if revocation_options.enabled: + if revocation_options.enabled: # pragma: no branch metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH) - metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] + metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"] return metadata diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 39e3212e91..49195415bf 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -3,10 +3,10 @@ from __future__ import annotations import types +from collections.abc import Sequence from typing import Generic, Literal, TypeVar, Union, get_args, get_origin from pydantic import BaseModel -from pydantic.fields import FieldInfo from mcp.server.session import ServerSession from mcp.types import RequestId @@ -36,6 +36,15 @@ class CancelledElicitation(BaseModel): ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation +class AcceptedUrlElicitation(BaseModel): + """Result when user accepts a URL mode elicitation.""" + + action: Literal["accept"] = "accept" + + +UrlElicitationResult = AcceptedUrlElicitation | DeclinedElicitation | CancelledElicitation + + # Primitive types allowed in elicitation schemas _ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) @@ -43,22 +52,40 @@ class CancelledElicitation(BaseModel): def _validate_elicitation_schema(schema: type[BaseModel]) -> None: """Validate that a Pydantic model only contains primitive field types.""" for field_name, field_info in schema.model_fields.items(): - if not _is_primitive_field(field_info): + annotation = field_info.annotation + + if annotation is None or annotation is types.NoneType: # pragma: no cover + continue + elif _is_primitive_field(annotation): + continue + elif _is_string_sequence(annotation): + continue + else: raise TypeError( f"Elicitation schema field '{field_name}' must be a primitive type " - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " - f"Complex types like lists, dicts, or nested models are not allowed." + f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), " + f"or Optional of these types. Nested models and complex types are not allowed." ) -def _is_primitive_field(field_info: FieldInfo) -> bool: - """Check if a field is a primitive type allowed in elicitation schemas.""" - annotation = field_info.annotation +def _is_string_sequence(annotation: type) -> bool: + """Check if annotation is a sequence of strings (list[str], Sequence[str], etc).""" + origin = get_origin(annotation) + # Check if it's a sequence-like type with str elements + if origin: + try: + if issubclass(origin, Sequence): + args = get_args(annotation) + # Should have single str type arg + return len(args) == 1 and args[0] is str + except TypeError: # pragma: no cover + # origin is not a class, so it can't be a subclass of Sequence + pass + return False - # Handle None type - if annotation is types.NoneType: - return True +def _is_primitive_field(annotation: type) -> bool: + """Check if a field is a primitive type allowed in elicitation schemas.""" # Handle basic primitive types if annotation in _ELICITATION_PRIMITIVE_TYPES: return True @@ -67,8 +94,10 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: origin = get_origin(annotation) if origin is Union or origin is types.UnionType: args = get_args(annotation) - # All args must be primitive types or None - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) + # All args must be primitive types, None, or string sequences + return all( + arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES or _is_string_sequence(arg) for arg in args + ) return False @@ -79,20 +108,22 @@ async def elicit_with_validation( schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user with schema validation. + """Elicit information from the client/user with schema validation (form mode). This method can be used to interactively ask for additional information from the client within a tool's execution. The client might display the message to the user and collect a response according to the provided schema. Or in case a client is an agent, it might decide how to handle the elicitation -- either by asking the user or automatically generating a response. + + For sensitive data like credentials or OAuth flows, use elicit_url() instead. """ # Validate that schema only contains primitive types and fail loudly if not _validate_elicitation_schema(schema) json_schema = schema.model_json_schema() - result = await session.elicit( + result = await session.elicit_form( message=message, requestedSchema=json_schema, related_request_id=related_request_id, @@ -104,8 +135,56 @@ async def elicit_with_validation( return AcceptedElicitation(data=validated_data) elif result.action == "decline": return DeclinedElicitation() + elif result.action == "cancel": # pragma: no cover + return CancelledElicitation() + else: # pragma: no cover + # This should never happen, but handle it just in case + raise ValueError(f"Unexpected elicitation action: {result.action}") + + +async def elicit_url( + session: ServerSession, + message: str, + url: str, + elicitation_id: str, + related_request_id: RequestId | None = None, +) -> UrlElicitationResult: + """Elicit information from the user via out-of-band URL navigation (URL mode). + + This method directs the user to an external URL where sensitive interactions can + occur without passing data through the MCP client. Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band. When the elicitation completes, + the server should send an ElicitCompleteNotification to notify the client. + + Args: + session: The server session + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + result = await session.elicit_url( + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=related_request_id, + ) + + if result.action == "accept": + return AcceptedUrlElicitation() + elif result.action == "decline": + return DeclinedElicitation() elif result.action == "cancel": return CancelledElicitation() - else: + else: # pragma: no cover # This should never happen, but handle it just in case raise ValueError(f"Unexpected elicitation action: {result.action}") diff --git a/src/mcp/server/experimental/__init__.py b/src/mcp/server/experimental/__init__.py new file mode 100644 index 0000000000..824bb8b8be --- /dev/null +++ b/src/mcp/server/experimental/__init__.py @@ -0,0 +1,11 @@ +""" +Server-side experimental features. + +WARNING: These APIs are experimental and may change without notice. + +Import directly from submodules: +- mcp.server.experimental.task_context.ServerTaskContext +- mcp.server.experimental.task_support.TaskSupport +- mcp.server.experimental.task_result_handler.TaskResultHandler +- mcp.server.experimental.request_context.Experimental +""" diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py new file mode 100644 index 0000000000..78e75beb6a --- /dev/null +++ b/src/mcp/server/experimental/request_context.py @@ -0,0 +1,238 @@ +""" +Experimental request context features. + +This module provides the Experimental class which gives access to experimental +features within a request context, such as task-augmented request handling. + +WARNING: These APIs are experimental and may change without notice. +""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any + +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.experimental.task_support import TaskSupport +from mcp.server.session import ServerSession +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal +from mcp.types import ( + METHOD_NOT_FOUND, + TASK_FORBIDDEN, + TASK_REQUIRED, + ClientCapabilities, + CreateTaskResult, + ErrorData, + Result, + TaskExecutionMode, + TaskMetadata, + Tool, +) + + +@dataclass +class Experimental: + """ + Experimental features context for task-augmented requests. + + Provides helpers for validating task execution compatibility and + running tasks with automatic lifecycle management. + + WARNING: This API is experimental and may change without notice. + """ + + task_metadata: TaskMetadata | None = None + _client_capabilities: ClientCapabilities | None = field(default=None, repr=False) + _session: ServerSession | None = field(default=None, repr=False) + _task_support: TaskSupport | None = field(default=None, repr=False) + + @property + def is_task(self) -> bool: + """Check if this request is task-augmented.""" + return self.task_metadata is not None + + @property + def client_supports_tasks(self) -> bool: + """Check if the client declared task support.""" + if self._client_capabilities is None: + return False + return self._client_capabilities.tasks is not None + + def validate_task_mode( + self, + tool_task_mode: TaskExecutionMode | None, + *, + raise_error: bool = True, + ) -> ErrorData | None: + """ + Validate that the request is compatible with the tool's task execution mode. + + Per MCP spec: + - "required": Clients MUST invoke as task. Server returns -32601 if not. + - "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do. + - "optional": Either is acceptable. + + Args: + tool_task_mode: The tool's execution.taskSupport value + ("forbidden", "optional", "required", or None) + raise_error: If True, raises McpError on validation failure. If False, returns ErrorData. + + Returns: + None if valid, ErrorData if invalid and raise_error=False + + Raises: + McpError: If invalid and raise_error=True + """ + + mode = tool_task_mode or TASK_FORBIDDEN + + error: ErrorData | None = None + + if mode == TASK_REQUIRED and not self.is_task: + error = ErrorData( + code=METHOD_NOT_FOUND, + message="This tool requires task-augmented invocation", + ) + elif mode == TASK_FORBIDDEN and self.is_task: + error = ErrorData( + code=METHOD_NOT_FOUND, + message="This tool does not support task-augmented invocation", + ) + + if error is not None and raise_error: + raise McpError(error) + + return error + + def validate_for_tool( + self, + tool: Tool, + *, + raise_error: bool = True, + ) -> ErrorData | None: + """ + Validate that the request is compatible with the given tool. + + Convenience wrapper around validate_task_mode that extracts the mode from a Tool. + + Args: + tool: The Tool definition + raise_error: If True, raises McpError on validation failure. + + Returns: + None if valid, ErrorData if invalid and raise_error=False + """ + mode = tool.execution.taskSupport if tool.execution else None + return self.validate_task_mode(mode, raise_error=raise_error) + + def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: + """ + Check if this client can use a tool with the given task mode. + + Useful for filtering tool lists or providing warnings. + Returns False if tool requires "required" but client doesn't support tasks. + + Args: + tool_task_mode: The tool's execution.taskSupport value + + Returns: + True if the client can use this tool, False otherwise + """ + mode = tool_task_mode or TASK_FORBIDDEN + if mode == TASK_REQUIRED and not self.client_supports_tasks: + return False + return True + + async def run_task( + self, + work: Callable[[ServerTaskContext], Awaitable[Result]], + *, + task_id: str | None = None, + model_immediate_response: str | None = None, + ) -> CreateTaskResult: + """ + Create a task, spawn background work, and return CreateTaskResult immediately. + + This is the recommended way to handle task-augmented tool calls. It: + 1. Creates a task in the store + 2. Spawns the work function in a background task + 3. Returns CreateTaskResult immediately + + The work function receives a ServerTaskContext with: + - elicit() for sending elicitation requests + - create_message() for sampling requests + - update_status() for progress updates + - complete()/fail() for finishing the task + + When work() returns a Result, the task is auto-completed with that result. + If work() raises an exception, the task is auto-failed. + + Args: + work: Async function that does the actual work + task_id: Optional task ID (generated if not provided) + model_immediate_response: Optional string to include in _meta as + io.modelcontextprotocol/model-immediate-response + + Returns: + CreateTaskResult to return to the client + + Raises: + RuntimeError: If task support is not enabled or task_metadata is missing + + Example: + @server.call_tool() + async def handle_tool(name: str, args: dict): + ctx = server.request_context + + async def work(task: ServerTaskContext) -> CallToolResult: + result = await task.elicit( + message="Are you sure?", + requestedSchema={"type": "object", ...} + ) + confirmed = result.content.get("confirm", False) + return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")]) + + return await ctx.experimental.run_task(work) + + WARNING: This API is experimental and may change without notice. + """ + if self._task_support is None: + raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.") + if self._session is None: + raise RuntimeError("Session not available.") + if self.task_metadata is None: + raise RuntimeError( + "Request is not task-augmented (no task field in params). " + "The client must send a task-augmented request." + ) + + support = self._task_support + # Access task_group via TaskSupport - raises if not in run() context + task_group = support.task_group + + task = await support.store.create_task(self.task_metadata, task_id) + + task_ctx = ServerTaskContext( + task=task, + store=support.store, + session=self._session, + queue=support.queue, + handler=support.handler, + ) + + async def execute() -> None: + try: + result = await work(task_ctx) + if not is_terminal(task_ctx.task.status): + await task_ctx.complete(result) + except Exception as e: + if not is_terminal(task_ctx.task.status): + await task_ctx.fail(str(e)) + + task_group.start_soon(execute) + + meta: dict[str, Any] | None = None + if model_immediate_response is not None: + meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response} + + return CreateTaskResult(task=task, **{"_meta": meta} if meta else {}) diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py new file mode 100644 index 0000000000..4842da5175 --- /dev/null +++ b/src/mcp/server/experimental/session_features.py @@ -0,0 +1,220 @@ +""" +Experimental server session features for server→client task operations. + +This module provides the server-side equivalent of ExperimentalClientFeatures, +allowing the server to send task-augmented requests to the client and poll for results. + +WARNING: These APIs are experimental and may change without notice. +""" + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, TypeVar + +import mcp.types as types +from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared.experimental.tasks.capabilities import ( + require_task_augmented_elicitation, + require_task_augmented_sampling, +) +from mcp.shared.experimental.tasks.polling import poll_until_terminal + +if TYPE_CHECKING: + from mcp.server.session import ServerSession + +ResultT = TypeVar("ResultT", bound=types.Result) + + +class ExperimentalServerSessionFeatures: + """ + Experimental server session features for server→client task operations. + + This provides the server-side equivalent of ExperimentalClientFeatures, + allowing the server to send task-augmented requests to the client and + poll for results. + + WARNING: These APIs are experimental and may change without notice. + + Access via session.experimental: + result = await session.experimental.elicit_as_task(...) + """ + + def __init__(self, session: "ServerSession") -> None: + self._session = session + + async def get_task(self, task_id: str) -> types.GetTaskResult: + """ + Send tasks/get to the client to get task status. + + Args: + task_id: The task identifier + + Returns: + GetTaskResult containing the task status + """ + return await self._session.send_request( + types.ServerRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))), + types.GetTaskResult, + ) + + async def get_task_result( + self, + task_id: str, + result_type: type[ResultT], + ) -> ResultT: + """ + Send tasks/result to the client to retrieve the final result. + + Args: + task_id: The task identifier + result_type: The expected result type + + Returns: + The task result, validated against result_type + """ + return await self._session.send_request( + types.ServerRequest(types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))), + result_type, + ) + + async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]: + """ + Poll a client task until it reaches terminal status. + + Yields GetTaskResult for each poll, allowing the caller to react to + status changes. Exits when task reaches a terminal status. + + Respects the pollInterval hint from the client. + + Args: + task_id: The task identifier + + Yields: + GetTaskResult for each poll + """ + async for status in poll_until_terminal(self.get_task, task_id): + yield status + + async def elicit_as_task( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + *, + ttl: int = 60000, + ) -> types.ElicitResult: + """ + Send a task-augmented elicitation to the client and poll until complete. + + The client will create a local task, process the elicitation asynchronously, + and return the result when ready. This method handles the full flow: + 1. Send elicitation with task field + 2. Receive CreateTaskResult from client + 3. Poll client's task until terminal + 4. Retrieve and return the final ElicitResult + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response + ttl: Task time-to-live in milliseconds + + Returns: + The client's elicitation response + + Raises: + McpError: If client doesn't support task-augmented elicitation + """ + client_caps = self._session.client_params.capabilities if self._session.client_params else None + require_task_augmented_elicitation(client_caps) + + create_result = await self._session.send_request( + types.ServerRequest( + types.ElicitRequest( + params=types.ElicitRequestFormParams( + message=message, + requestedSchema=requestedSchema, + task=types.TaskMetadata(ttl=ttl), + ) + ) + ), + types.CreateTaskResult, + ) + + task_id = create_result.task.taskId + + async for _ in self.poll_task(task_id): + pass + + return await self.get_task_result(task_id, types.ElicitResult) + + async def create_message_as_task( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + ttl: int = 60000, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool] | None = None, + tool_choice: types.ToolChoice | None = None, + ) -> types.CreateMessageResult: + """ + Send a task-augmented sampling request and poll until complete. + + The client will create a local task, process the sampling request + asynchronously, and return the result when ready. + + Args: + messages: The conversation messages for sampling + max_tokens: Maximum tokens in the response + ttl: Task time-to-live in milliseconds + system_prompt: Optional system prompt + include_context: Context inclusion strategy + temperature: Sampling temperature + stop_sequences: Stop sequences + metadata: Additional metadata + model_preferences: Model selection preferences + tools: Optional list of tools the LLM can use during sampling + tool_choice: Optional control over tool usage behavior + + Returns: + The sampling result from the client + + Raises: + McpError: If client doesn't support task-augmented sampling or tools + ValueError: If tool_use or tool_result message structure is invalid + """ + client_caps = self._session.client_params.capabilities if self._session.client_params else None + require_task_augmented_sampling(client_caps) + validate_sampling_tools(client_caps, tools, tool_choice) + validate_tool_use_result_messages(messages) + + create_result = await self._session.send_request( + types.ServerRequest( + types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=messages, + maxTokens=max_tokens, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + stopSequences=stop_sequences, + metadata=metadata, + modelPreferences=model_preferences, + tools=tools, + toolChoice=tool_choice, + task=types.TaskMetadata(ttl=ttl), + ) + ) + ), + types.CreateTaskResult, + ) + + task_id = create_result.task.taskId + + async for _ in self.poll_task(task_id): + pass + + return await self.get_task_result(task_id, types.CreateMessageResult) diff --git a/src/mcp/server/experimental/task_context.py b/src/mcp/server/experimental/task_context.py new file mode 100644 index 0000000000..e6e14fc938 --- /dev/null +++ b/src/mcp/server/experimental/task_context.py @@ -0,0 +1,612 @@ +""" +ServerTaskContext - Server-integrated task context with elicitation and sampling. + +This wraps the pure TaskContext and adds server-specific functionality: +- Elicitation (task.elicit()) +- Sampling (task.create_message()) +- Status notifications +""" + +from typing import Any + +import anyio + +from mcp.server.experimental.task_result_handler import TaskResultHandler +from mcp.server.session import ServerSession +from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.capabilities import ( + require_task_augmented_elicitation, + require_task_augmented_sampling, +) +from mcp.shared.experimental.tasks.context import TaskContext +from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import ( + INVALID_REQUEST, + TASK_STATUS_INPUT_REQUIRED, + TASK_STATUS_WORKING, + ClientCapabilities, + CreateMessageResult, + CreateTaskResult, + ElicitationCapability, + ElicitRequestedSchema, + ElicitResult, + ErrorData, + IncludeContext, + ModelPreferences, + RequestId, + Result, + SamplingCapability, + SamplingMessage, + ServerNotification, + Task, + TaskMetadata, + TaskStatusNotification, + TaskStatusNotificationParams, + Tool, + ToolChoice, +) + + +class ServerTaskContext: + """ + Server-integrated task context with elicitation and sampling. + + This wraps a pure TaskContext and adds server-specific functionality: + - elicit() for sending elicitation requests to the client + - create_message() for sampling requests + - Status notifications via the session + + Example: + async def my_task_work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Starting...") + + result = await task.elicit( + message="Continue?", + requestedSchema={"type": "object", "properties": {"ok": {"type": "boolean"}}} + ) + + if result.content.get("ok"): + return CallToolResult(content=[TextContent(text="Done!")]) + else: + return CallToolResult(content=[TextContent(text="Cancelled")]) + """ + + def __init__( + self, + *, + task: Task, + store: TaskStore, + session: ServerSession, + queue: TaskMessageQueue, + handler: TaskResultHandler | None = None, + ): + """ + Create a ServerTaskContext. + + Args: + task: The Task object + store: The task store + session: The server session + queue: The message queue for elicitation/sampling + handler: The result handler for response routing (required for elicit/create_message) + """ + self._ctx = TaskContext(task=task, store=store) + self._session = session + self._queue = queue + self._handler = handler + self._store = store + + # Delegate pure properties to inner context + + @property + def task_id(self) -> str: + """The task identifier.""" + return self._ctx.task_id + + @property + def task(self) -> Task: + """The current task state.""" + return self._ctx.task + + @property + def is_cancelled(self) -> bool: + """Whether cancellation has been requested.""" + return self._ctx.is_cancelled + + def request_cancellation(self) -> None: + """Request cancellation of this task.""" + self._ctx.request_cancellation() + + # Enhanced methods with notifications + + async def update_status(self, message: str, *, notify: bool = True) -> None: + """ + Update the task's status message. + + Args: + message: The new status message + notify: Whether to send a notification to the client + """ + await self._ctx.update_status(message) + if notify: + await self._send_notification() + + async def complete(self, result: Result, *, notify: bool = True) -> None: + """ + Mark the task as completed with the given result. + + Args: + result: The task result + notify: Whether to send a notification to the client + """ + await self._ctx.complete(result) + if notify: + await self._send_notification() + + async def fail(self, error: str, *, notify: bool = True) -> None: + """ + Mark the task as failed with an error message. + + Args: + error: The error message + notify: Whether to send a notification to the client + """ + await self._ctx.fail(error) + if notify: + await self._send_notification() + + async def _send_notification(self) -> None: + """Send a task status notification to the client.""" + task = self._ctx.task + await self._session.send_notification( + ServerNotification( + TaskStatusNotification( + params=TaskStatusNotificationParams( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + ) + ) + ) + + # Server-specific methods: elicitation and sampling + + def _check_elicitation_capability(self) -> None: + """Check if the client supports elicitation.""" + if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())): + raise McpError( + ErrorData( + code=INVALID_REQUEST, + message="Client does not support elicitation capability", + ) + ) + + def _check_sampling_capability(self) -> None: + """Check if the client supports sampling.""" + if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())): + raise McpError( + ErrorData( + code=INVALID_REQUEST, + message="Client does not support sampling capability", + ) + ) + + async def elicit( + self, + message: str, + requestedSchema: ElicitRequestedSchema, + ) -> ElicitResult: + """ + Send an elicitation request via the task message queue. + + This method: + 1. Checks client capability + 2. Updates task status to "input_required" + 3. Queues the elicitation request + 4. Waits for the response (delivered via tasks/result round-trip) + 5. Updates task status back to "working" + 6. Returns the result + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + + Returns: + The client's response + + Raises: + McpError: If client doesn't support elicitation capability + """ + self._check_elicitation_capability() + + if self._handler is None: + raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.") + + # Update status to input_required + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + + # Build the request using session's helper + request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] + message=message, + requestedSchema=requestedSchema, + related_task_id=self.task_id, + ) + request_id: RequestId = request.id + + resolver: Resolver[dict[str, Any]] = Resolver() + self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] + + queued = QueuedMessage( + type="request", + message=request, + resolver=resolver, + original_request_id=request_id, + ) + await self._queue.enqueue(self.task_id, queued) + + try: + # Wait for response (routed back via TaskResultHandler) + response_data = await resolver.wait() + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + return ElicitResult.model_validate(response_data) + except anyio.get_cancelled_exc_class(): # pragma: no cover + # Coverage can't track async exception handlers reliably. + # This path is tested in test_elicit_restores_status_on_cancellation + # which verifies status is restored to "working" after cancellation. + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + raise + + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + ) -> ElicitResult: + """ + Send a URL mode elicitation request via the task message queue. + + This directs the user to an external URL for out-of-band interactions + like OAuth flows, credential collection, or payment processing. + + This method: + 1. Checks client capability + 2. Updates task status to "input_required" + 3. Queues the elicitation request + 4. Waits for the response (delivered via tasks/result round-trip) + 5. Updates task status back to "working" + 6. Returns the result + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + + Returns: + The client's response indicating acceptance, decline, or cancellation + + Raises: + McpError: If client doesn't support elicitation capability + RuntimeError: If handler is not configured + """ + self._check_elicitation_capability() + + if self._handler is None: + raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.") + + # Update status to input_required + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + + # Build the request using session's helper + request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage] + message=message, + url=url, + elicitation_id=elicitation_id, + related_task_id=self.task_id, + ) + request_id: RequestId = request.id + + resolver: Resolver[dict[str, Any]] = Resolver() + self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] + + queued = QueuedMessage( + type="request", + message=request, + resolver=resolver, + original_request_id=request_id, + ) + await self._queue.enqueue(self.task_id, queued) + + try: + # Wait for response (routed back via TaskResultHandler) + response_data = await resolver.wait() + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + return ElicitResult.model_validate(response_data) + except anyio.get_cancelled_exc_class(): # pragma: no cover + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + raise + + async def create_message( + self, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + ) -> CreateMessageResult: + """ + Send a sampling request via the task message queue. + + This method: + 1. Checks client capability + 2. Updates task status to "input_required" + 3. Queues the sampling request + 4. Waits for the response (delivered via tasks/result round-trip) + 5. Updates task status back to "working" + 6. Returns the result + + Args: + messages: The conversation messages for sampling + max_tokens: Maximum tokens in the response + system_prompt: Optional system prompt + include_context: Context inclusion strategy + temperature: Sampling temperature + stop_sequences: Stop sequences + metadata: Additional metadata + model_preferences: Model selection preferences + tools: Optional list of tools the LLM can use during sampling + tool_choice: Optional control over tool usage behavior + + Returns: + The sampling result from the client + + Raises: + McpError: If client doesn't support sampling capability or tools + ValueError: If tool_use or tool_result message structure is invalid + """ + self._check_sampling_capability() + client_caps = self._session.client_params.capabilities if self._session.client_params else None + validate_sampling_tools(client_caps, tools, tool_choice) + validate_tool_use_result_messages(messages) + + if self._handler is None: + raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.") + + # Update status to input_required + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + + # Build the request using session's helper + request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] + messages=messages, + max_tokens=max_tokens, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + related_task_id=self.task_id, + ) + request_id: RequestId = request.id + + resolver: Resolver[dict[str, Any]] = Resolver() + self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] + + queued = QueuedMessage( + type="request", + message=request, + resolver=resolver, + original_request_id=request_id, + ) + await self._queue.enqueue(self.task_id, queued) + + try: + # Wait for response (routed back via TaskResultHandler) + response_data = await resolver.wait() + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + return CreateMessageResult.model_validate(response_data) + except anyio.get_cancelled_exc_class(): # pragma: no cover + # Coverage can't track async exception handlers reliably. + # This path is tested in test_create_message_restores_status_on_cancellation + # which verifies status is restored to "working" after cancellation. + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + raise + + async def elicit_as_task( + self, + message: str, + requestedSchema: ElicitRequestedSchema, + *, + ttl: int = 60000, + ) -> ElicitResult: + """ + Send a task-augmented elicitation via the queue, then poll client. + + This is for use inside a task-augmented tool call when you want the client + to handle the elicitation as its own task. The elicitation request is queued + and delivered when the client calls tasks/result. After the client responds + with CreateTaskResult, we poll the client's task until complete. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + ttl: Task time-to-live in milliseconds for the client's task + + Returns: + The client's elicitation response + + Raises: + McpError: If client doesn't support task-augmented elicitation + RuntimeError: If handler is not configured + """ + client_caps = self._session.client_params.capabilities if self._session.client_params else None + require_task_augmented_elicitation(client_caps) + + if self._handler is None: + raise RuntimeError("handler is required for elicit_as_task()") + + # Update status to input_required + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + + request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage] + message=message, + requestedSchema=requestedSchema, + related_task_id=self.task_id, + task=TaskMetadata(ttl=ttl), + ) + request_id: RequestId = request.id + + resolver: Resolver[dict[str, Any]] = Resolver() + self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] + + queued = QueuedMessage( + type="request", + message=request, + resolver=resolver, + original_request_id=request_id, + ) + await self._queue.enqueue(self.task_id, queued) + + try: + # Wait for initial response (CreateTaskResult from client) + response_data = await resolver.wait() + create_result = CreateTaskResult.model_validate(response_data) + client_task_id = create_result.task.taskId + + # Poll the client's task using session.experimental + async for _ in self._session.experimental.poll_task(client_task_id): + pass + + # Get final result from client + result = await self._session.experimental.get_task_result( + client_task_id, + ElicitResult, + ) + + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + return result + + except anyio.get_cancelled_exc_class(): # pragma: no cover + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + raise + + async def create_message_as_task( + self, + messages: list[SamplingMessage], + *, + max_tokens: int, + ttl: int = 60000, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + ) -> CreateMessageResult: + """ + Send a task-augmented sampling request via the queue, then poll client. + + This is for use inside a task-augmented tool call when you want the client + to handle the sampling as its own task. The request is queued and delivered + when the client calls tasks/result. After the client responds with + CreateTaskResult, we poll the client's task until complete. + + Args: + messages: The conversation messages for sampling + max_tokens: Maximum tokens in the response + ttl: Task time-to-live in milliseconds for the client's task + system_prompt: Optional system prompt + include_context: Context inclusion strategy + temperature: Sampling temperature + stop_sequences: Stop sequences + metadata: Additional metadata + model_preferences: Model selection preferences + tools: Optional list of tools the LLM can use during sampling + tool_choice: Optional control over tool usage behavior + + Returns: + The sampling result from the client + + Raises: + McpError: If client doesn't support task-augmented sampling or tools + ValueError: If tool_use or tool_result message structure is invalid + RuntimeError: If handler is not configured + """ + client_caps = self._session.client_params.capabilities if self._session.client_params else None + require_task_augmented_sampling(client_caps) + validate_sampling_tools(client_caps, tools, tool_choice) + validate_tool_use_result_messages(messages) + + if self._handler is None: + raise RuntimeError("handler is required for create_message_as_task()") + + # Update status to input_required + await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED) + + # Build request WITH task field for task-augmented sampling + request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage] + messages=messages, + max_tokens=max_tokens, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + related_task_id=self.task_id, + task=TaskMetadata(ttl=ttl), + ) + request_id: RequestId = request.id + + resolver: Resolver[dict[str, Any]] = Resolver() + self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage] + + queued = QueuedMessage( + type="request", + message=request, + resolver=resolver, + original_request_id=request_id, + ) + await self._queue.enqueue(self.task_id, queued) + + try: + # Wait for initial response (CreateTaskResult from client) + response_data = await resolver.wait() + create_result = CreateTaskResult.model_validate(response_data) + client_task_id = create_result.task.taskId + + # Poll the client's task using session.experimental + async for _ in self._session.experimental.poll_task(client_task_id): + pass + + # Get final result from client + result = await self._session.experimental.get_task_result( + client_task_id, + CreateMessageResult, + ) + + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + return result + + except anyio.get_cancelled_exc_class(): # pragma: no cover + await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING) + raise diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py new file mode 100644 index 0000000000..0b869216e8 --- /dev/null +++ b/src/mcp/server/experimental/task_result_handler.py @@ -0,0 +1,235 @@ +""" +TaskResultHandler - Integrated handler for tasks/result endpoint. + +This implements the dequeue-send-wait pattern from the MCP Tasks spec: +1. Dequeue all pending messages for the task +2. Send them to the client via transport with relatedRequestId routing +3. Wait if task is not in terminal state +4. Return final result when task completes + +This is the core of the task message queue pattern. +""" + +import logging +from typing import Any + +import anyio + +from mcp.server.session import ServerSession +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal +from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.types import ( + INVALID_PARAMS, + ErrorData, + GetTaskPayloadRequest, + GetTaskPayloadResult, + JSONRPCMessage, + RelatedTaskMetadata, + RequestId, +) + +logger = logging.getLogger(__name__) + + +class TaskResultHandler: + """ + Handler for tasks/result that implements the message queue pattern. + + This handler: + 1. Dequeues pending messages (elicitations, notifications) for the task + 2. Sends them to the client via the response stream + 3. Waits for responses and resolves them back to callers + 4. Blocks until task reaches terminal state + 5. Returns the final result + + Usage: + # Create handler with store and queue + handler = TaskResultHandler(task_store, message_queue) + + # Register it with the server + @server.experimental.get_task_result() + async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: + ctx = server.request_context + return await handler.handle(req, ctx.session, ctx.request_id) + + # Or use the convenience method + handler.register(server) + """ + + def __init__( + self, + store: TaskStore, + queue: TaskMessageQueue, + ): + self._store = store + self._queue = queue + # Map from internal request ID to resolver for routing responses + self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {} + + async def send_message( + self, + session: ServerSession, + message: SessionMessage, + ) -> None: + """ + Send a message via the session. + + This is a helper for delivering queued task messages. + """ + await session.send_message(message) + + async def handle( + self, + request: GetTaskPayloadRequest, + session: ServerSession, + request_id: RequestId, + ) -> GetTaskPayloadResult: + """ + Handle a tasks/result request. + + This implements the dequeue-send-wait loop: + 1. Dequeue all pending messages + 2. Send each via transport with relatedRequestId = this request's ID + 3. If task not terminal, wait for status change + 4. Loop until task is terminal + 5. Return final result + + Args: + request: The GetTaskPayloadRequest + session: The server session for sending messages + request_id: The request ID for relatedRequestId routing + + Returns: + GetTaskPayloadResult with the task's final payload + """ + task_id = request.params.taskId + + while True: + task = await self._store.get_task(task_id) + if task is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {task_id}", + ) + ) + + await self._deliver_queued_messages(task_id, session, request_id) + + # If task is terminal, return result + if is_terminal(task.status): + result = await self._store.get_result(task_id) + # GetTaskPayloadResult is a Result with extra="allow" + # The stored result contains the actual payload data + # Per spec: tasks/result MUST include _meta with related-task metadata + related_task = RelatedTaskMetadata(taskId=task_id) + related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)} + if result is not None: + result_data = result.model_dump(by_alias=True) + existing_meta: dict[str, Any] = result_data.get("_meta") or {} + result_data["_meta"] = {**existing_meta, **related_task_meta} + return GetTaskPayloadResult.model_validate(result_data) + return GetTaskPayloadResult.model_validate({"_meta": related_task_meta}) + + # Wait for task update (status change or new messages) + await self._wait_for_task_update(task_id) + + async def _deliver_queued_messages( + self, + task_id: str, + session: ServerSession, + request_id: RequestId, + ) -> None: + """ + Dequeue and send all pending messages for a task. + + Each message is sent via the session's write stream with + relatedRequestId set so responses route back to this stream. + """ + while True: + message = await self._queue.dequeue(task_id) + if message is None: + break + + # If this is a request (not notification), wait for response + if message.type == "request" and message.resolver is not None: + # Store the resolver so we can route the response back + original_id = message.original_request_id + if original_id is not None: + self._pending_requests[original_id] = message.resolver + + logger.debug("Delivering queued message for task %s: %s", task_id, message.type) + + # Send the message with relatedRequestId for routing + session_message = SessionMessage( + message=JSONRPCMessage(message.message), + metadata=ServerMessageMetadata(related_request_id=request_id), + ) + await self.send_message(session, session_message) + + async def _wait_for_task_update(self, task_id: str) -> None: + """ + Wait for task to be updated (status change or new message). + + Races between store update and queue message - first one wins. + """ + async with anyio.create_task_group() as tg: + + async def wait_for_store() -> None: + try: + await self._store.wait_for_update(task_id) + except Exception: + pass + finally: + tg.cancel_scope.cancel() + + async def wait_for_queue() -> None: + try: + await self._queue.wait_for_message(task_id) + except Exception: + pass + finally: + tg.cancel_scope.cancel() + + tg.start_soon(wait_for_store) + tg.start_soon(wait_for_queue) + + def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: + """ + Route a response back to the waiting resolver. + + This is called when a response arrives for a queued request. + + Args: + request_id: The request ID from the response + response: The response data + + Returns: + True if response was routed, False if no pending request + """ + resolver = self._pending_requests.pop(request_id, None) + if resolver is not None and not resolver.done(): + resolver.set_result(response) + return True + return False + + def route_error(self, request_id: RequestId, error: ErrorData) -> bool: + """ + Route an error back to the waiting resolver. + + Args: + request_id: The request ID from the error response + error: The error data + + Returns: + True if error was routed, False if no pending request + """ + resolver = self._pending_requests.pop(request_id, None) + if resolver is not None and not resolver.done(): + resolver.set_exception(McpError(error)) + return True + return False diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py new file mode 100644 index 0000000000..dbb2ed6d2b --- /dev/null +++ b/src/mcp/server/experimental/task_support.py @@ -0,0 +1,115 @@ +""" +TaskSupport - Configuration for experimental task support. + +This module provides the TaskSupport class which encapsulates all the +infrastructure needed for task-augmented requests: store, queue, and handler. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field + +import anyio +from anyio.abc import TaskGroup + +from mcp.server.experimental.task_result_handler import TaskResultHandler +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue +from mcp.shared.experimental.tasks.store import TaskStore + + +@dataclass +class TaskSupport: + """ + Configuration for experimental task support. + + Encapsulates the task store, message queue, result handler, and task group + for spawning background work. + + When enabled on a server, this automatically: + - Configures response routing for each session + - Provides default handlers for task operations + - Manages a task group for background task execution + + Example: + # Simple in-memory setup + server.experimental.enable_tasks() + + # Custom store/queue for distributed systems + server.experimental.enable_tasks( + store=RedisTaskStore(redis_url), + queue=RedisTaskMessageQueue(redis_url), + ) + """ + + store: TaskStore + queue: TaskMessageQueue + handler: TaskResultHandler = field(init=False) + _task_group: TaskGroup | None = field(init=False, default=None) + + def __post_init__(self) -> None: + """Create the result handler from store and queue.""" + self.handler = TaskResultHandler(self.store, self.queue) + + @property + def task_group(self) -> TaskGroup: + """Get the task group for spawning background work. + + Raises: + RuntimeError: If not within a run() context + """ + if self._task_group is None: + raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.") + return self._task_group + + @asynccontextmanager + async def run(self) -> AsyncIterator[None]: + """ + Run the task support lifecycle. + + This creates a task group for spawning background task work. + Called automatically by Server.run(). + + Usage: + async with task_support.run(): + # Task group is now available + ... + """ + async with anyio.create_task_group() as tg: + self._task_group = tg + try: + yield + finally: + self._task_group = None + + def configure_session(self, session: ServerSession) -> None: + """ + Configure a session for task support. + + This registers the result handler as a response router so that + responses to queued requests (elicitation, sampling) are routed + back to the waiting resolvers. + + Called automatically by Server.run() for each new session. + + Args: + session: The session to configure + """ + session.add_response_router(self.handler) + + @classmethod + def in_memory(cls) -> "TaskSupport": + """ + Create in-memory task support. + + Suitable for development, testing, and single-process servers. + For distributed systems, provide custom store and queue implementations. + + Returns: + TaskSupport configured with in-memory store and queue + """ + return cls( + store=InMemoryTaskStore(), + queue=InMemoryTaskMessageQueue(), + ) diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index 4bf4389c15..48c65b57c5 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -94,11 +94,11 @@ def from_function( """ func_name = name or fn.__name__ - if func_name == "": + if func_name == "": # pragma: no cover raise ValueError("You must provide a name for lambda functions") # Find context parameter if it exists - if context_kwarg is None: + if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) # Get schema from func_metadata, excluding context parameter @@ -110,7 +110,7 @@ def from_function( # Convert parameters to PromptArguments arguments: list[PromptArgument] = [] - if "properties" in parameters: + if "properties" in parameters: # pragma: no branch for param_name, param in parameters["properties"].items(): required = param_name in parameters.get("required", []) arguments.append( @@ -172,12 +172,12 @@ async def render( elif isinstance(msg, str): content = TextContent(type="text", text=msg) messages.append(UserMessage(content=content)) - else: + else: # pragma: no cover content = pydantic_core.to_json(msg, fallback=str, indent=2).decode() messages.append(Message(role="user", content=content)) - except Exception: + except Exception: # pragma: no cover raise ValueError(f"Could not convert prompt result to message: {msg}") return messages - except Exception as e: + except Exception as e: # pragma: no cover raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/server/fastmcp/resources/base.py b/src/mcp/server/fastmcp/resources/base.py index 0bef1a2663..557775eab5 100644 --- a/src/mcp/server/fastmcp/resources/base.py +++ b/src/mcp/server/fastmcp/resources/base.py @@ -13,7 +13,7 @@ field_validator, ) -from mcp.types import Icon +from mcp.types import Annotations, Icon class Resource(BaseModel, abc.ABC): @@ -28,9 +28,10 @@ class Resource(BaseModel, abc.ABC): mime_type: str = Field( default="text/plain", description="MIME type of the resource content", - pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$", + pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+(;\s*[a-zA-Z0-9\-_.]+=[a-zA-Z0-9\-_.]+)*$", ) icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this resource") + annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource") @field_validator("name", mode="before") @classmethod @@ -45,4 +46,4 @@ def set_default_name(cls, name: str | None, info: ValidationInfo) -> str: @abc.abstractmethod async def read(self) -> str | bytes: """Read the resource content.""" - pass + pass # pragma: no cover diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index b2865def8f..2e7dc171bc 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -10,7 +10,7 @@ from mcp.server.fastmcp.resources.base import Resource from mcp.server.fastmcp.resources.templates import ResourceTemplate from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.types import Icon +from mcp.types import Annotations, Icon if TYPE_CHECKING: from mcp.server.fastmcp.server import Context @@ -63,6 +63,7 @@ def add_template( description: str | None = None, mime_type: str | None = None, icons: list[Icon] | None = None, + annotations: Annotations | None = None, ) -> ResourceTemplate: """Add a template from a function.""" template = ResourceTemplate.from_function( @@ -73,6 +74,7 @@ def add_template( description=description, mime_type=mime_type, icons=icons, + annotations=annotations, ) self._templates[template.uri_template] = template return template @@ -95,7 +97,7 @@ async def get_resource( if params := template.matches(uri_str): try: return await template.create_resource(uri_str, params, context=context) - except Exception as e: + except Exception as e: # pragma: no cover raise ValueError(f"Error creating resource from template: {e}") raise ValueError(f"Unknown resource: {uri}") diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index 8b5af2574c..a98d37f0ac 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -12,7 +12,7 @@ from mcp.server.fastmcp.resources.types import FunctionResource, Resource from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context from mcp.server.fastmcp.utilities.func_metadata import func_metadata -from mcp.types import Icon +from mcp.types import Annotations, Icon if TYPE_CHECKING: from mcp.server.fastmcp.server import Context @@ -29,6 +29,7 @@ class ResourceTemplate(BaseModel): description: str | None = Field(description="Description of what the resource does") mime_type: str = Field(default="text/plain", description="MIME type of the resource content") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for the resource template") + annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource template") fn: Callable[..., Any] = Field(exclude=True) parameters: dict[str, Any] = Field(description="JSON schema for function parameters") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") @@ -43,15 +44,16 @@ def from_function( description: str | None = None, mime_type: str | None = None, icons: list[Icon] | None = None, + annotations: Annotations | None = None, context_kwarg: str | None = None, ) -> ResourceTemplate: """Create a template from a function.""" func_name = name or fn.__name__ if func_name == "": - raise ValueError("You must provide a name for lambda functions") + raise ValueError("You must provide a name for lambda functions") # pragma: no cover # Find context parameter if it exists - if context_kwarg is None: + if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) # Get schema from func_metadata, excluding context parameter @@ -71,6 +73,7 @@ def from_function( description=description or fn.__doc__ or "", mime_type=mime_type or "text/plain", icons=icons, + annotations=annotations, fn=fn, parameters=parameters, context_kwarg=context_kwarg, @@ -108,6 +111,7 @@ async def create_resource( description=self.description, mime_type=self.mime_type, icons=self.icons, + annotations=self.annotations, fn=lambda: result, # Capture result in closure ) except Exception as e: diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py index c578e23de3..680e72dc09 100644 --- a/src/mcp/server/fastmcp/resources/types.py +++ b/src/mcp/server/fastmcp/resources/types.py @@ -14,7 +14,7 @@ from pydantic import AnyUrl, Field, ValidationInfo, validate_call from mcp.server.fastmcp.resources.base import Resource -from mcp.types import Icon +from mcp.types import Annotations, Icon class TextResource(Resource): @@ -24,7 +24,7 @@ class TextResource(Resource): async def read(self) -> str: """Read the text content.""" - return self.text + return self.text # pragma: no cover class BinaryResource(Resource): @@ -34,7 +34,7 @@ class BinaryResource(Resource): async def read(self) -> bytes: """Read the binary content.""" - return self.data + return self.data # pragma: no cover class FunctionResource(Resource): @@ -61,7 +61,7 @@ async def read(self) -> str | bytes: if inspect.iscoroutine(result): result = await result - if isinstance(result, Resource): + if isinstance(result, Resource): # pragma: no cover return await result.read() elif isinstance(result, bytes): return result @@ -82,10 +82,11 @@ def from_function( description: str | None = None, mime_type: str | None = None, icons: list[Icon] | None = None, + annotations: Annotations | None = None, ) -> "FunctionResource": """Create a FunctionResource from a function.""" func_name = name or fn.__name__ - if func_name == "": + if func_name == "": # pragma: no cover raise ValueError("You must provide a name for lambda functions") # ensure the arguments are properly cast @@ -99,6 +100,7 @@ def from_function( mime_type=mime_type or "text/plain", fn=fn, icons=icons, + annotations=annotations, ) @@ -120,7 +122,7 @@ class FileResource(Resource): @pydantic.field_validator("path") @classmethod - def validate_absolute_path(cls, path: Path) -> Path: + def validate_absolute_path(cls, path: Path) -> Path: # pragma: no cover """Ensure path is absolute.""" if not path.is_absolute(): raise ValueError("Path must be absolute") @@ -153,7 +155,7 @@ class HttpResource(Resource): async def read(self) -> str | bytes: """Read the HTTP content.""" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient() as client: # pragma: no cover response = await client.get(self.url) response.raise_for_status() return response.text @@ -169,13 +171,13 @@ class DirectoryResource(Resource): @pydantic.field_validator("path") @classmethod - def validate_absolute_path(cls, path: Path) -> Path: + def validate_absolute_path(cls, path: Path) -> Path: # pragma: no cover """Ensure path is absolute.""" if not path.is_absolute(): raise ValueError("Path must be absolute") return path - def list_files(self) -> list[Path]: + def list_files(self) -> list[Path]: # pragma: no cover """List files in the directory.""" if not self.path.exists(): raise FileNotFoundError(f"Directory not found: {self.path}") @@ -189,7 +191,7 @@ def list_files(self) -> list[Path]: except Exception as e: raise ValueError(f"Error listing directory {self.path}: {e}") - async def read(self) -> str: # Always returns JSON string + async def read(self) -> str: # Always returns JSON string # pragma: no cover """Read the directory listing.""" try: files = await anyio.to_thread.run_sync(self.list_files) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 485ef15198..51e93b6776 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -4,7 +4,14 @@ import inspect import re -from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence +from collections.abc import ( + AsyncIterator, + Awaitable, + Callable, + Collection, + Iterable, + Sequence, +) from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Any, Generic, Literal @@ -22,10 +29,25 @@ from starlette.types import Receive, Scope, Send from mcp.server.auth.middleware.auth_context import AuthContextMiddleware -from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware -from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier +from mcp.server.auth.middleware.bearer_auth import ( + BearerAuthBackend, + RequireAuthMiddleware, +) +from mcp.server.auth.provider import ( + OAuthAuthorizationServerProvider, + ProviderTokenVerifier, + TokenVerifier, +) from mcp.server.auth.settings import AuthSettings -from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation +from mcp.server.elicitation import ( + ElicitationResult, + ElicitSchemaModelT, + UrlElicitationResult, + elicit_with_validation, +) +from mcp.server.elicitation import ( + elicit_url as _elicit_url, +) from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -43,7 +65,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import LifespanContextT, RequestContext, RequestT -from mcp.types import AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations +from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -112,7 +134,9 @@ def lifespan_wrapper( lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], ) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(_: MCPServer[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: + async def wrap( + _: MCPServer[LifespanResultT, Request], + ) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context @@ -126,9 +150,10 @@ def __init__( # noqa: PLR0913 instructions: str | None = None, website_url: str | None = None, icons: list[Icon] | None = None, - auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, + auth_server_provider: (OAuthAuthorizationServerProvider[Any, Any, Any] | None) = None, token_verifier: TokenVerifier | None = None, event_store: EventStore | None = None, + retry_interval: int | None = None, *, tools: list[Tool] | None = None, debug: bool = False, @@ -145,10 +170,18 @@ def __init__( # noqa: PLR0913 warn_on_duplicate_tools: bool = True, warn_on_duplicate_prompts: bool = True, dependencies: Collection[str] = (), - lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, + lifespan: (Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None) = None, auth: AuthSettings | None = None, transport_security: TransportSecuritySettings | None = None, ): + # Auto-enable DNS rebinding protection for localhost (IPv4 and IPv6) + if transport_security is None and host in ("127.0.0.1", "localhost", "::1"): + transport_security = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["127.0.0.1:*", "localhost:*", "[::1]:*"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://[::1]:*"], + ) + self.settings = Settings( debug=debug, log_level=log_level, @@ -183,20 +216,21 @@ def __init__( # noqa: PLR0913 self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) # Validate auth configuration if self.settings.auth is not None: - if auth_server_provider and token_verifier: + if auth_server_provider and token_verifier: # pragma: no cover raise ValueError("Cannot specify both auth_server_provider and token_verifier") - if not auth_server_provider and not token_verifier: + if not auth_server_provider and not token_verifier: # pragma: no cover raise ValueError("Must specify either auth_server_provider or token_verifier when auth is enabled") - elif auth_server_provider or token_verifier: + elif auth_server_provider or token_verifier: # pragma: no cover raise ValueError("Cannot specify auth_server_provider or token_verifier without auth settings") self._auth_server_provider = auth_server_provider self._token_verifier = token_verifier # Create token verifier from provider if needed (backwards compatibility) - if auth_server_provider and not token_verifier: + if auth_server_provider and not token_verifier: # pragma: no cover self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._event_store = event_store + self._retry_interval = retry_interval self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies self._session_manager: StreamableHTTPSessionManager | None = None @@ -233,14 +267,14 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - if self._session_manager is None: + if self._session_manager is None: # pragma: no cover raise RuntimeError( "Session manager can only be accessed after" "calling streamable_http_app()." "The session manager is created lazily" "to avoid unnecessary initialization." ) - return self._session_manager + return self._session_manager # pragma: no cover def run( self, @@ -254,15 +288,15 @@ def run( mount_path: Optional mount path for SSE transport """ TRANSPORTS = Literal["stdio", "sse", "streamable-http"] - if transport not in TRANSPORTS.__args__: # type: ignore + if transport not in TRANSPORTS.__args__: # type: ignore # pragma: no cover raise ValueError(f"Unknown transport: {transport}") match transport: case "stdio": anyio.run(self.run_stdio_async) - case "sse": + case "sse": # pragma: no cover anyio.run(lambda: self.run_sse_async(mount_path)) - case "streamable-http": + case "streamable-http": # pragma: no cover anyio.run(self.run_streamable_http_async) def _setup_handlers(self) -> None: @@ -290,6 +324,7 @@ async def list_tools(self) -> list[MCPTool]: outputSchema=info.output_schema, annotations=info.annotations, icons=info.icons, + _meta=info.meta, ) for info in tools ] @@ -322,6 +357,7 @@ async def list_resources(self) -> list[MCPResource]: description=resource.description, mimeType=resource.mime_type, icons=resource.icons, + annotations=resource.annotations, ) for resource in resources ] @@ -336,6 +372,7 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: description=template.description, mimeType=template.mime_type, icons=template.icons, + annotations=template.annotations, ) for template in templates ] @@ -345,13 +382,13 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent context = self.get_context() resource = await self._resource_manager.get_resource(uri, context=context) - if not resource: + if not resource: # pragma: no cover raise ResourceError(f"Unknown resource: {uri}") try: content = await resource.read() return [ReadResourceContents(content=content, mime_type=resource.mime_type)] - except Exception as e: + except Exception as e: # pragma: no cover logger.exception(f"Error reading resource {uri}") raise ResourceError(str(e)) @@ -363,6 +400,7 @@ def add_tool( description: str | None = None, annotations: ToolAnnotations | None = None, icons: list[Icon] | None = None, + meta: dict[str, Any] | None = None, structured_output: bool | None = None, ) -> None: """Add a tool to the server. @@ -388,6 +426,7 @@ def add_tool( description=description, annotations=annotations, icons=icons, + meta=meta, structured_output=structured_output, ) @@ -409,6 +448,7 @@ def tool( description: str | None = None, annotations: ToolAnnotations | None = None, icons: list[Icon] | None = None, + meta: dict[str, Any] | None = None, structured_output: bool | None = None, ) -> Callable[[AnyFunction], AnyFunction]: """Decorator to register a tool. @@ -456,6 +496,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: description=description, annotations=annotations, icons=icons, + meta=meta, structured_output=structured_output, ) return fn @@ -497,6 +538,7 @@ def resource( description: str | None = None, mime_type: str | None = None, icons: list[Icon] | None = None, + annotations: Annotations | None = None, ) -> Callable[[AnyFunction], AnyFunction]: """Decorator to register a function as a resource. @@ -572,6 +614,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: description=description, mime_type=mime_type, icons=icons, + annotations=annotations, ) else: # Register as regular resource @@ -583,6 +626,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: description=description, mime_type=mime_type, icons=icons, + annotations=annotations, ) self.add_resource(resource) return fn @@ -667,6 +711,10 @@ def custom_route( The handler function must be an async function that accepts a Starlette Request and returns a Response. + Routes using this decorator will not require authorization. It is intended + for uses that are either a part of authorization flows or intended to be + public such as health check endpoints. + Args: path: URL path for the route (e.g., "/oauth/callback") methods: List of HTTP methods to support (e.g., ["GET", "POST"]) @@ -680,7 +728,7 @@ async def health_check(request: Request) -> Response: return JSONResponse({"status": "ok"}) """ - def decorator( + def decorator( # pragma: no cover func: Callable[[Request], Awaitable[Response]], ) -> Callable[[Request], Awaitable[Response]]: self._custom_starlette_routes.append( @@ -694,7 +742,7 @@ def decorator( ) return func - return decorator + return decorator # pragma: no cover async def run_stdio_async(self) -> None: """Run the server using stdio transport.""" @@ -705,7 +753,7 @@ async def run_stdio_async(self) -> None: self._mcp_server.create_initialization_options(), ) - async def run_sse_async(self, mount_path: str | None = None) -> None: + async def run_sse_async(self, mount_path: str | None = None) -> None: # pragma: no cover """Run the server using SSE transport.""" import uvicorn @@ -720,7 +768,7 @@ async def run_sse_async(self, mount_path: str | None = None) -> None: server = uvicorn.Server(config) await server.serve() - async def run_streamable_http_async(self) -> None: + async def run_streamable_http_async(self) -> None: # pragma: no cover """Run the server using StreamableHTTP transport.""" import uvicorn @@ -780,7 +828,7 @@ def sse_app(self, mount_path: str | None = None) -> Starlette: security_settings=self.settings.transport_security, ) - async def handle_sse(scope: Scope, receive: Receive, send: Send): + async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no cover # Add client ID from auth context into request context if available async with sse.connect_sse( @@ -801,7 +849,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): required_scopes = [] # Set up auth if configured - if self.settings.auth: + if self.settings.auth: # pragma: no cover required_scopes = self.settings.auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -832,7 +880,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): ) # When auth is configured, require authentication - if self._token_verifier: + if self._token_verifier: # pragma: no cover # Determine resource metadata URL resource_metadata_url = None if self.settings.auth and self.settings.auth.resource_server_url: @@ -855,7 +903,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): app=RequireAuthMiddleware(sse.handle_post_message, required_scopes, resource_metadata_url), ) ) - else: + else: # pragma: no cover # Auth is disabled, no need for RequireAuthMiddleware # Since handle_sse is an ASGI app, we need to create a compatible endpoint async def sse_endpoint(request: Request) -> Response: @@ -876,7 +924,7 @@ async def sse_endpoint(request: Request) -> Response: ) ) # Add protected resource metadata endpoint if configured as RS - if self.settings.auth and self.settings.auth.resource_server_url: + if self.settings.auth and self.settings.auth.resource_server_url: # pragma: no cover from mcp.server.auth.routes import create_protected_resource_routes routes.extend( @@ -898,10 +946,11 @@ def streamable_http_app(self) -> Starlette: from starlette.middleware import Middleware # Create session manager on first call (lazy initialization) - if self._session_manager is None: + if self._session_manager is None: # pragma: no branch self._session_manager = StreamableHTTPSessionManager( app=self._mcp_server, event_store=self._event_store, + retry_interval=self._retry_interval, json_response=self.settings.json_response, stateless=self.settings.stateless_http, # Use the stateless setting security_settings=self.settings.transport_security, @@ -916,7 +965,7 @@ def streamable_http_app(self) -> Starlette: required_scopes = [] # Set up auth if configured - if self.settings.auth: + if self.settings.auth: # pragma: no cover required_scopes = self.settings.auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -944,7 +993,7 @@ def streamable_http_app(self) -> Starlette: ) # Set up routes with or without auth - if self._token_verifier: + if self._token_verifier: # pragma: no cover # Determine resource metadata URL resource_metadata_url = None if self.settings.auth and self.settings.auth.resource_server_url: @@ -969,7 +1018,7 @@ def streamable_http_app(self) -> Starlette: ) # Add protected resource metadata endpoint if configured as RS - if self.settings.auth and self.settings.auth.resource_server_url: + if self.settings.auth and self.settings.auth.resource_server_url: # pragma: no cover from mcp.server.auth.routes import create_protected_resource_routes routes.extend( @@ -1036,7 +1085,7 @@ class StreamableHTTPASGIApp: def __init__(self, session_manager: StreamableHTTPSessionManager): self.session_manager = session_manager - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover await self.session_manager.handle_request(scope, receive, send) @@ -1091,16 +1140,16 @@ def __init__( @property def fastmcp(self) -> FastMCP: """Access to the FastMCP server.""" - if self._fastmcp is None: + if self._fastmcp is None: # pragma: no cover raise ValueError("Context is not available outside of a request") - return self._fastmcp + return self._fastmcp # pragma: no cover @property def request_context( self, ) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]: """Access to the underlying request context.""" - if self._request_context is None: + if self._request_context is None: # pragma: no cover raise ValueError("Context is not available outside of a request") return self._request_context @@ -1114,7 +1163,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes """ progress_token = self.request_context.meta.progressToken if self.request_context.meta else None - if progress_token is None: + if progress_token is None: # pragma: no cover return await self.request_context.session.send_progress_notification( @@ -1164,7 +1213,45 @@ async def elicit( """ return await elicit_with_validation( - session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id + session=self.request_context.session, + message=message, + schema=schema, + related_request_id=self.request_id, + ) + + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + ) -> UrlElicitationResult: + """Request URL mode elicitation from the client. + + This directs the user to an external URL for out-of-band interactions + that must not pass through the MCP client. Use this for: + - Collecting sensitive credentials (API keys, passwords) + - OAuth authorization flows with third-party services + - Payment and subscription flows + - Any interaction where data should not pass through the LLM context + + The response indicates whether the user consented to navigate to the URL. + The actual interaction happens out-of-band. When the elicitation completes, + call `self.session.send_elicit_complete(elicitation_id)` to notify the client. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + + Returns: + UrlElicitationResult indicating accept, decline, or cancel + """ + return await _elicit_url( + session=self.request_context.session, + message=message, + url=url, + elicitation_id=elicitation_id, + related_request_id=self.request_id, ) async def log( @@ -1173,6 +1260,7 @@ async def log( message: str, *, logger_name: str | None = None, + extra: dict[str, Any] | None = None, ) -> None: """Send a log message to the client. @@ -1180,11 +1268,20 @@ async def log( level: Log level (debug, info, warning, error) message: Log message logger_name: Optional logger name - **extra: Additional structured data to include + extra: Optional dictionary with additional structured data to include """ + + if extra: + log_data = { + "message": message, + **extra, + } + else: + log_data = message + await self.request_context.session.send_log_message( level=level, - data=message, + data=log_data, logger=logger_name, related_request_id=self.request_id, ) @@ -1192,7 +1289,9 @@ async def log( @property def client_id(self) -> str | None: """Get the client ID if available.""" - return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None + return ( + getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None + ) # pragma: no cover @property def request_id(self) -> str: @@ -1204,19 +1303,53 @@ def session(self): """Access to the underlying session for advanced usage.""" return self.request_context.session + async def close_sse_stream(self) -> None: + """Close the SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the current request, triggering + client reconnection. Events continue to be stored in the event store and will + be replayed when the client reconnects with Last-Event-ID. + + Use this to implement polling behavior during long-running operations - + client will reconnect after the retry interval specified in the priming event. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + The callback is only available when event_store is configured. + """ + if self._request_context and self._request_context.close_sse_stream: # pragma: no cover + await self._request_context.close_sse_stream() + + async def close_standalone_sse_stream(self) -> None: + """Close the standalone GET SSE stream to trigger client reconnection. + + This method closes the HTTP connection for the standalone GET stream used + for unsolicited server-to-client notifications. The client SHOULD reconnect + with Last-Event-ID to resume receiving notifications. + + Note: + This is a no-op if not using StreamableHTTP transport with event_store. + Currently, client reconnection for standalone GET streams is NOT + implemented - this is a known gap. + """ + if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover + await self._request_context.close_standalone_sse_stream() + # Convenience methods for common log levels - async def debug(self, message: str, **extra: Any) -> None: + async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: """Send a debug log message.""" - await self.log("debug", message, **extra) + await self.log("debug", message, logger_name=logger_name, extra=extra) - async def info(self, message: str, **extra: Any) -> None: + async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: """Send an info log message.""" - await self.log("info", message, **extra) + await self.log("info", message, logger_name=logger_name, extra=extra) - async def warning(self, message: str, **extra: Any) -> None: + async def warning( + self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None + ) -> None: """Send a warning log message.""" - await self.log("warning", message, **extra) + await self.log("warning", message, logger_name=logger_name, extra=extra) - async def error(self, message: str, **extra: Any) -> None: + async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: """Send an error log message.""" - await self.log("error", message, **extra) + await self.log("error", message, logger_name=logger_name, extra=extra) diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 3f26ddcea6..1ae6d90d19 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -11,6 +11,8 @@ from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.utilities.context_injection import find_context_parameter from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations if TYPE_CHECKING: @@ -34,6 +36,7 @@ class Tool(BaseModel): context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool") + meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool") @cached_property def output_schema(self) -> dict[str, Any] | None: @@ -49,18 +52,21 @@ def from_function( context_kwarg: str | None = None, annotations: ToolAnnotations | None = None, icons: list[Icon] | None = None, + meta: dict[str, Any] | None = None, structured_output: bool | None = None, ) -> Tool: """Create a Tool from a function.""" func_name = name or fn.__name__ + validate_and_warn_tool_name(func_name) + if func_name == "": raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" is_async = _is_async_callable(fn) - if context_kwarg is None: + if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) func_arg_metadata = func_metadata( @@ -81,6 +87,7 @@ def from_function( context_kwarg=context_kwarg, annotations=annotations, icons=icons, + meta=meta, ) async def run( @@ -102,12 +109,16 @@ async def run( result = self.fn_metadata.convert_result(result) return result + except UrlElicitationRequiredError: + # Re-raise UrlElicitationRequiredError so it can be properly handled + # as an MCP error response with code -32042 + raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): + while isinstance(obj, functools.partial): # pragma: no cover obj = obj.func return inspect.iscoroutinefunction(obj) or ( diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index d6c0054af7..095753de69 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -50,6 +50,7 @@ def add_tool( description: str | None = None, annotations: ToolAnnotations | None = None, icons: list[Icon] | None = None, + meta: dict[str, Any] | None = None, structured_output: bool | None = None, ) -> Tool: """Add a tool to the server.""" @@ -60,6 +61,7 @@ def add_tool( description=description, annotations=annotations, icons=icons, + meta=meta, structured_output=structured_output, ) existing = self._tools.get(tool.name) diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 3289a5aa62..fa443d2fcb 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Callable, Sequence from itertools import chain from types import GenericAlias -from typing import Annotated, Any, ForwardRef, cast, get_args, get_origin, get_type_hints +from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints import pydantic_core from pydantic import ( @@ -14,15 +14,21 @@ WithJsonSchema, create_model, ) -from pydantic._internal._typing_extra import eval_type_backport from pydantic.fields import FieldInfo from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind -from pydantic_core import PydanticUndefined +from typing_extensions import is_typeddict +from typing_inspection.introspection import ( + UNKNOWN, + AnnotationSource, + ForbiddenQualifier, + inspect_annotation, + is_union_origin, +) from mcp.server.fastmcp.exceptions import InvalidSignature from mcp.server.fastmcp.utilities.logging import get_logger from mcp.server.fastmcp.utilities.types import Audio, Image -from mcp.types import ContentBlock, TextContent +from mcp.types import CallToolResult, ContentBlock, TextContent logger = get_logger(__name__) @@ -104,6 +110,12 @@ def convert_result(self, result: Any) -> Any: from function return values, whereas the lowlevel server simply serializes the structured output. """ + if isinstance(result, CallToolResult): + if self.output_schema is not None: + assert self.output_model is not None, "Output model must be set if output schema is defined" + self.output_model.model_validate(result.structuredContent) + return result + unstructured_content = _convert_to_content(result) if self.output_schema is None: @@ -140,7 +152,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: key_to_field_info[field_info.alias] = field_info for data_key, data_value in data.items(): - if data_key not in key_to_field_info: + if data_key not in key_to_field_info: # pragma: no cover continue field_info = key_to_field_info[data_key] @@ -205,56 +217,47 @@ def func_metadata( - output_model: A pydantic model for the return type if output is structured - output_conversion: Records how function output should be converted before returning. """ - sig = _get_typed_signature(func) + try: + sig = inspect.signature(func, eval_str=True) + except NameError as e: # pragma: no cover + # This raise could perhaps be skipped, and we (FastMCP) just call + # model_rebuild right before using it 🤷 + raise InvalidSignature(f"Unable to evaluate type annotations for callable {func.__name__!r}") from e params = sig.parameters dynamic_pydantic_model_params: dict[str, Any] = {} - globalns = getattr(func, "__globals__", {}) for param in params.values(): - if param.name.startswith("_"): + if param.name.startswith("_"): # pragma: no cover raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'") if param.name in skip_names: continue - annotation = param.annotation - - # `x: None` / `x: None = None` - if annotation is None: - annotation = Annotated[ - None, - Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined), - ] - - # Untyped field - if annotation is inspect.Parameter.empty: - annotation = Annotated[ - Any, - Field(), - # 🤷 - WithJsonSchema({"title": param.name, "type": "string"}), - ] - - field_info = FieldInfo.from_annotated_attribute( - _get_typed_annotation(annotation, globalns), - param.default if param.default is not inspect.Parameter.empty else PydanticUndefined, - ) + annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any + field_name = param.name + field_kwargs: dict[str, Any] = {} + field_metadata: list[Any] = [] + + if param.annotation is inspect.Parameter.empty: + field_metadata.append(WithJsonSchema({"title": param.name, "type": "string"})) # Check if the parameter name conflicts with BaseModel attributes # This is necessary because Pydantic warns about shadowing parent attributes - if hasattr(BaseModel, param.name) and callable(getattr(BaseModel, param.name)): + if hasattr(BaseModel, field_name) and callable(getattr(BaseModel, field_name)): # Use an alias to avoid the shadowing warning - field_info.alias = param.name - field_info.validation_alias = param.name - field_info.serialization_alias = param.name - # Use a prefixed internal name - internal_name = f"field_{param.name}" - dynamic_pydantic_model_params[internal_name] = (field_info.annotation, field_info) + field_kwargs["alias"] = field_name + # Use a prefixed field name + field_name = f"field_{field_name}" + + if param.default is not inspect.Parameter.empty: + dynamic_pydantic_model_params[field_name] = ( + Annotated[(annotation, *field_metadata, Field(**field_kwargs))], + param.default, + ) else: - dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) - continue + dynamic_pydantic_model_params[field_name] = Annotated[(annotation, *field_metadata, Field(**field_kwargs))] arguments_model = create_model( f"{func.__name__}Arguments", - **dynamic_pydantic_model_params, __base__=ArgModelBase, + **dynamic_pydantic_model_params, ) if structured_output is False: @@ -265,15 +268,56 @@ def func_metadata( if sig.return_annotation is inspect.Parameter.empty and structured_output is True: raise InvalidSignature(f"Function {func.__name__}: return annotation required for structured output") - output_info = FieldInfo.from_annotation(_get_typed_annotation(sig.return_annotation, globalns)) - annotation = output_info.annotation + try: + inspected_return_ann = inspect_annotation(sig.return_annotation, annotation_source=AnnotationSource.FUNCTION) + except ForbiddenQualifier as e: + raise InvalidSignature(f"Function {func.__name__}: return annotation contains an invalid type qualifier") from e + + return_type_expr = inspected_return_ann.type - output_model, output_schema, wrap_output = _try_create_model_and_schema(annotation, func.__name__, output_info) + # `AnnotationSource.FUNCTION` allows no type qualifier to be used, so `return_type_expr` is guaranteed to *not* be + # unknown (i.e. a bare `Final`). + assert return_type_expr is not UNKNOWN + + if is_union_origin(get_origin(return_type_expr)): + args = get_args(return_type_expr) + # Check if CallToolResult appears in the union (excluding None for Optional check) + if any(isinstance(arg, type) and issubclass(arg, CallToolResult) for arg in args if arg is not type(None)): + raise InvalidSignature( + f"Function {func.__name__}: CallToolResult cannot be used in Union or Optional types. " + "To return empty results, use: CallToolResult(content=[])" + ) + + original_annotation: Any + # if the typehint is CallToolResult, the user either intends to return without validation + # or they provided validation as Annotated metadata + if isinstance(return_type_expr, type) and issubclass(return_type_expr, CallToolResult): + if inspected_return_ann.metadata: + return_type_expr = inspected_return_ann.metadata[0] + if len(inspected_return_ann.metadata) >= 2: + # Reconstruct the original annotation, by preserving the remaining metadata, + # i.e. from `Annotated[CallToolResult, ReturnType, Gt(1)]` to + # `Annotated[ReturnType, Gt(1)]`: + original_annotation = Annotated[ + (return_type_expr, *inspected_return_ann.metadata[1:]) + ] # pragma: no cover + else: + # We only had `Annotated[CallToolResult, ReturnType]`, treat the original annotation + # as beging `ReturnType`: + original_annotation = return_type_expr + else: + return FuncMetadata(arg_model=arguments_model) + else: + original_annotation = sig.return_annotation + + output_model, output_schema, wrap_output = _try_create_model_and_schema( + original_annotation, return_type_expr, func.__name__ + ) if output_model is None and structured_output is True: # Model creation failed or produced warnings - no structured output raise InvalidSignature( - f"Function {func.__name__}: return type {annotation} is not serializable for structured output" + f"Function {func.__name__}: return type {return_type_expr} is not serializable for structured output" ) return FuncMetadata( @@ -285,10 +329,18 @@ def func_metadata( def _try_create_model_and_schema( - annotation: Any, func_name: str, field_info: FieldInfo + original_annotation: Any, + type_expr: Any, + func_name: str, ) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]: """Try to create a model and schema for the given annotation without warnings. + Args: + original_annotation: The original return annotation (may be wrapped in `Annotated`). + type_expr: The underlying type expression derived from the return annotation + (`Annotated` and type qualifiers were stripped). + func_name: The name of the function. + Returns: tuple of (model or None, schema or None, wrap_output) Model and schema are None if warnings occur or creation fails. @@ -298,43 +350,45 @@ def _try_create_model_and_schema( wrap_output = False # First handle special case: None - if annotation is None: - model = _create_wrapped_model(func_name, annotation, field_info) + if type_expr is None: + model = _create_wrapped_model(func_name, original_annotation) wrap_output = True # Handle GenericAlias types (list[str], dict[str, int], Union[str, int], etc.) - elif isinstance(annotation, GenericAlias): - origin = get_origin(annotation) + elif isinstance(type_expr, GenericAlias): + origin = get_origin(type_expr) # Special case: dict with string keys can use RootModel if origin is dict: - args = get_args(annotation) + args = get_args(type_expr) if len(args) == 2 and args[0] is str: - model = _create_dict_model(func_name, annotation) + # TODO: should we use the original annotation? We are loosing any potential `Annotated` + # metadata for Pydantic here: + model = _create_dict_model(func_name, type_expr) else: # dict with non-str keys needs wrapping - model = _create_wrapped_model(func_name, annotation, field_info) + model = _create_wrapped_model(func_name, original_annotation) wrap_output = True else: # All other generic types need wrapping (list, tuple, Union, Optional, etc.) - model = _create_wrapped_model(func_name, annotation, field_info) + model = _create_wrapped_model(func_name, original_annotation) wrap_output = True # Handle regular type objects - elif isinstance(annotation, type): - type_annotation: type[Any] = cast(type[Any], annotation) + elif isinstance(type_expr, type): + type_annotation = cast(type[Any], type_expr) # Case 1: BaseModel subclasses (can be used directly) - if issubclass(annotation, BaseModel): - model = annotation + if issubclass(type_annotation, BaseModel): + model = type_annotation - # Case 2: TypedDict (special dict subclass with __annotations__) - elif hasattr(type_annotation, "__annotations__") and issubclass(annotation, dict): + # Case 2: TypedDicts: + elif is_typeddict(type_annotation): model = _create_model_from_typeddict(type_annotation) # Case 3: Primitive types that need wrapping - elif annotation in (str, int, float, bool, bytes, type(None)): - model = _create_wrapped_model(func_name, annotation, field_info) + elif type_annotation in (str, int, float, bool, bytes, type(None)): + model = _create_wrapped_model(func_name, original_annotation) wrap_output = True # Case 4: Other class types (dataclasses, regular classes with annotations) @@ -342,14 +396,14 @@ def _try_create_model_and_schema( type_hints = get_type_hints(type_annotation) if type_hints: # Classes with type hints can be converted to Pydantic models - model = _create_model_from_class(type_annotation) + model = _create_model_from_class(type_annotation, type_hints) # Classes without type hints are not serializable - model remains None # Handle any other types not covered above else: # This includes typing constructs that aren't GenericAlias in Python 3.10 # (e.g., Union, Optional in some Python versions) - model = _create_wrapped_model(func_name, annotation, field_info) + model = _create_wrapped_model(func_name, original_annotation) wrap_output = True if model: @@ -363,7 +417,7 @@ def _try_create_model_and_schema( # ValueError: When there are issues with the type definition (including our custom warnings) # SchemaError: When Pydantic can't build a schema # ValidationError: When validation fails - logger.info(f"Cannot create schema for type {annotation} in {func_name}: {type(e).__name__}: {e}") + logger.info(f"Cannot create schema for type {type_expr} in {func_name}: {type(e).__name__}: {e}") return None, None, False return model, schema, wrap_output @@ -371,7 +425,10 @@ def _try_create_model_and_schema( return None, None, False -def _create_model_from_class(cls: type[Any]) -> type[BaseModel]: +_no_default = object() + + +def _create_model_from_class(cls: type[Any], type_hints: dict[str, Any]) -> type[BaseModel]: """Create a Pydantic model from an ordinary class. The created model will: @@ -379,24 +436,20 @@ def _create_model_from_class(cls: type[Any]) -> type[BaseModel]: - Have fields with the same names and types as the class's fields - Include all fields whose type does not include None in the set of required fields - Precondition: cls must have type hints (i.e., get_type_hints(cls) is non-empty) + Precondition: cls must have type hints (i.e., `type_hints` is non-empty) """ - type_hints = get_type_hints(cls) - model_fields: dict[str, Any] = {} for field_name, field_type in type_hints.items(): - if field_name.startswith("_"): + if field_name.startswith("_"): # pragma: no cover continue - default = getattr(cls, field_name, PydanticUndefined) - field_info = FieldInfo.from_annotated_attribute(field_type, default) - model_fields[field_name] = (field_info.annotation, field_info) - - # Create a base class with the config - class BaseWithConfig(BaseModel): - model_config = ConfigDict(from_attributes=True) + default = getattr(cls, field_name, _no_default) + if default is _no_default: + model_fields[field_name] = field_type + else: + model_fields[field_name] = (field_type, default) - return create_model(cls.__name__, **model_fields, __base__=BaseWithConfig) + return create_model(cls.__name__, __config__=ConfigDict(from_attributes=True), **model_fields) def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]: @@ -409,31 +462,25 @@ def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]: model_fields: dict[str, Any] = {} for field_name, field_type in type_hints.items(): - field_info = FieldInfo.from_annotation(field_type) - if field_name not in required_keys: # For optional TypedDict fields, set default=None # This makes them not required in the Pydantic model # The model should use exclude_unset=True when dumping to get TypedDict semantics - field_info.default = None - - model_fields[field_name] = (field_info.annotation, field_info) + model_fields[field_name] = (field_type, None) + else: + model_fields[field_name] = field_type - return create_model(td_type.__name__, **model_fields, __base__=BaseModel) + return create_model(td_type.__name__, **model_fields) -def _create_wrapped_model(func_name: str, annotation: Any, field_info: FieldInfo) -> type[BaseModel]: +def _create_wrapped_model(func_name: str, annotation: Any) -> type[BaseModel]: """Create a model that wraps a type in a 'result' field. This is used for primitive types, generic types like list/dict, etc. """ model_name = f"{func_name}Output" - # Pydantic needs type(None) instead of None for the type annotation - if annotation is None: - annotation = type(None) - - return create_model(model_name, result=(annotation, field_info), __base__=BaseModel) + return create_model(model_name, result=annotation) def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]: @@ -449,43 +496,6 @@ class DictModel(RootModel[dict_annotation]): return DictModel -def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: - def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]: - try: - return eval_type_backport(value, globalns, localns), True - except NameError: - return value, False - - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - annotation, status = try_eval_type(annotation, globalns, globalns) - - # This check and raise could perhaps be skipped, and we (FastMCP) just call - # model_rebuild right before using it 🤷 - if status is False: - raise InvalidSignature(f"Unable to evaluate type annotation {annotation}") - - return annotation - - -def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: - """Get function signature while evaluating forward references""" - signature = inspect.signature(call) - globalns = getattr(call, "__globals__", {}) - typed_params = [ - inspect.Parameter( - name=param.name, - kind=param.kind, - default=param.default, - annotation=_get_typed_annotation(param.annotation, globalns), - ) - for param in signature.parameters.values() - ] - typed_return = _get_typed_annotation(signature.return_annotation, globalns) - typed_signature = inspect.Signature(typed_params, return_annotation=typed_return) - return typed_signature - - def _convert_to_content( result: Any, ) -> Sequence[ContentBlock]: @@ -497,7 +507,7 @@ def _convert_to_content( output than the lowlevel server tool call handler, which just serializes structured content verbatim. """ - if result is None: + if result is None: # pragma: no cover return [] if isinstance(result, ContentBlock): diff --git a/src/mcp/server/fastmcp/utilities/logging.py b/src/mcp/server/fastmcp/utilities/logging.py index 091d57e69d..4b47d3b882 100644 --- a/src/mcp/server/fastmcp/utilities/logging.py +++ b/src/mcp/server/fastmcp/utilities/logging.py @@ -25,15 +25,15 @@ def configure_logging( level: the log level to use """ handlers: list[logging.Handler] = [] - try: + try: # pragma: no cover from rich.console import Console from rich.logging import RichHandler handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) - except ImportError: + except ImportError: # pragma: no cover pass - if not handlers: + if not handlers: # pragma: no cover handlers.append(logging.StreamHandler()) logging.basicConfig( diff --git a/src/mcp/server/fastmcp/utilities/types.py b/src/mcp/server/fastmcp/utilities/types.py index 1be6f82748..d6928ca3f8 100644 --- a/src/mcp/server/fastmcp/utilities/types.py +++ b/src/mcp/server/fastmcp/utilities/types.py @@ -15,9 +15,9 @@ def __init__( data: bytes | None = None, format: str | None = None, ): - if path is None and data is None: + if path is None and data is None: # pragma: no cover raise ValueError("Either path or data must be provided") - if path is not None and data is not None: + if path is not None and data is not None: # pragma: no cover raise ValueError("Only one of path or data can be provided") self.path = Path(path) if path else None @@ -27,7 +27,7 @@ def __init__( def _get_mime_type(self) -> str: """Get MIME type from format or guess from file extension.""" - if self._format: + if self._format: # pragma: no cover return f"image/{self._format.lower()}" if self.path: @@ -39,16 +39,16 @@ def _get_mime_type(self) -> str: ".gif": "image/gif", ".webp": "image/webp", }.get(suffix, "application/octet-stream") - return "image/png" # default for raw binary data + return "image/png" # pragma: no cover # default for raw binary data def to_image_content(self) -> ImageContent: """Convert to MCP ImageContent.""" if self.path: with open(self.path, "rb") as f: data = base64.b64encode(f.read()).decode() - elif self.data is not None: + elif self.data is not None: # pragma: no cover data = base64.b64encode(self.data).decode() - else: + else: # pragma: no cover raise ValueError("No image data available") return ImageContent(type="image", data=data, mimeType=self._mime_type) @@ -63,7 +63,7 @@ def __init__( data: bytes | None = None, format: str | None = None, ): - if not bool(path) ^ bool(data): + if not bool(path) ^ bool(data): # pragma: no cover raise ValueError("Either path or data can be provided") self.path = Path(path) if path else None @@ -73,7 +73,7 @@ def __init__( def _get_mime_type(self) -> str: """Get MIME type from format or guess from file extension.""" - if self._format: + if self._format: # pragma: no cover return f"audio/{self._format.lower()}" if self.path: @@ -86,16 +86,16 @@ def _get_mime_type(self) -> str: ".aac": "audio/aac", ".m4a": "audio/mp4", }.get(suffix, "application/octet-stream") - return "audio/wav" # default for raw binary data + return "audio/wav" # pragma: no cover # default for raw binary data def to_audio_content(self) -> AudioContent: """Convert to MCP AudioContent.""" if self.path: with open(self.path, "rb") as f: data = base64.b64encode(f.read()).decode() - elif self.data is not None: + elif self.data is not None: # pragma: no cover data = base64.b64encode(self.data).decode() - else: + else: # pragma: no cover raise ValueError("No audio data available") return AudioContent(type="audio", data=data, mimeType=self._mime_type) diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py new file mode 100644 index 0000000000..0e6655b3de --- /dev/null +++ b/src/mcp/server/lowlevel/experimental.py @@ -0,0 +1,288 @@ +"""Experimental handlers for the low-level MCP server. + +WARNING: These APIs are experimental and may change without notice. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING + +from mcp.server.experimental.task_support import TaskSupport +from mcp.server.lowlevel.func_inspection import create_call_wrapper +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.helpers import cancel_task +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import ( + INVALID_PARAMS, + CancelTaskRequest, + CancelTaskResult, + ErrorData, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerCapabilities, + ServerResult, + ServerTasksCapability, + ServerTasksRequestsCapability, + TasksCancelCapability, + TasksListCapability, + TasksToolsCapability, +) + +if TYPE_CHECKING: + from mcp.server.lowlevel.server import Server + +logger = logging.getLogger(__name__) + + +class ExperimentalHandlers: + """Experimental request/notification handlers. + + WARNING: These APIs are experimental and may change without notice. + """ + + def __init__( + self, + server: Server, + request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], + notification_handlers: dict[type, Callable[..., Awaitable[None]]], + ): + self._server = server + self._request_handlers = request_handlers + self._notification_handlers = notification_handlers + self._task_support: TaskSupport | None = None + + @property + def task_support(self) -> TaskSupport | None: + """Get the task support configuration, if enabled.""" + return self._task_support + + def update_capabilities(self, capabilities: ServerCapabilities) -> None: + # Only add tasks capability if handlers are registered + if not any( + req_type in self._request_handlers + for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest] + ): + return + + capabilities.tasks = ServerTasksCapability() + if ListTasksRequest in self._request_handlers: + capabilities.tasks.list = TasksListCapability() + if CancelTaskRequest in self._request_handlers: + capabilities.tasks.cancel = TasksCancelCapability() + + capabilities.tasks.requests = ServerTasksRequestsCapability( + tools=TasksToolsCapability() + ) # assuming always supported for now + + def enable_tasks( + self, + store: TaskStore | None = None, + queue: TaskMessageQueue | None = None, + ) -> TaskSupport: + """ + Enable experimental task support. + + This sets up the task infrastructure and auto-registers default handlers + for tasks/get, tasks/result, tasks/list, and tasks/cancel. + + Args: + store: Custom TaskStore implementation (defaults to InMemoryTaskStore) + queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue) + + Returns: + The TaskSupport configuration object + + Example: + # Simple in-memory setup + server.experimental.enable_tasks() + + # Custom store/queue for distributed systems + server.experimental.enable_tasks( + store=RedisTaskStore(redis_url), + queue=RedisTaskMessageQueue(redis_url), + ) + + WARNING: This API is experimental and may change without notice. + """ + if store is None: + store = InMemoryTaskStore() + if queue is None: + queue = InMemoryTaskMessageQueue() + + self._task_support = TaskSupport(store=store, queue=queue) + + # Auto-register default handlers + self._register_default_task_handlers() + + return self._task_support + + def _register_default_task_handlers(self) -> None: + """Register default handlers for task operations.""" + assert self._task_support is not None + support = self._task_support + + # Register get_task handler if not already registered + if GetTaskRequest not in self._request_handlers: + + async def _default_get_task(req: GetTaskRequest) -> ServerResult: + task = await support.store.get_task(req.params.taskId) + if task is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {req.params.taskId}", + ) + ) + return ServerResult( + GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + ) + + self._request_handlers[GetTaskRequest] = _default_get_task + + # Register get_task_result handler if not already registered + if GetTaskPayloadRequest not in self._request_handlers: + + async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: + ctx = self._server.request_context + result = await support.handler.handle(req, ctx.session, ctx.request_id) + return ServerResult(result) + + self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result + + # Register list_tasks handler if not already registered + if ListTasksRequest not in self._request_handlers: + + async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: + cursor = req.params.cursor if req.params else None + tasks, next_cursor = await support.store.list_tasks(cursor) + return ServerResult(ListTasksResult(tasks=tasks, nextCursor=next_cursor)) + + self._request_handlers[ListTasksRequest] = _default_list_tasks + + # Register cancel_task handler if not already registered + if CancelTaskRequest not in self._request_handlers: + + async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult: + result = await cancel_task(support.store, req.params.taskId) + return ServerResult(result) + + self._request_handlers[CancelTaskRequest] = _default_cancel_task + + def list_tasks( + self, + ) -> Callable[ + [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], + Callable[[ListTasksRequest], Awaitable[ListTasksResult]], + ]: + """Register a handler for listing tasks. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], + ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: + logger.debug("Registering handler for ListTasksRequest") + wrapper = create_call_wrapper(func, ListTasksRequest) + + async def handler(req: ListTasksRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[ListTasksRequest] = handler + return func + + return decorator + + def get_task( + self, + ) -> Callable[ + [Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]] + ]: + """Register a handler for getting task status. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]], + ) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]: + logger.debug("Registering handler for GetTaskRequest") + wrapper = create_call_wrapper(func, GetTaskRequest) + + async def handler(req: GetTaskRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[GetTaskRequest] = handler + return func + + return decorator + + def get_task_result( + self, + ) -> Callable[ + [Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]], + Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], + ]: + """Register a handler for getting task results/payload. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]], + ) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]: + logger.debug("Registering handler for GetTaskPayloadRequest") + wrapper = create_call_wrapper(func, GetTaskPayloadRequest) + + async def handler(req: GetTaskPayloadRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[GetTaskPayloadRequest] = handler + return func + + return decorator + + def cancel_task( + self, + ) -> Callable[ + [Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]], + Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], + ]: + """Register a handler for cancelling tasks. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]], + ) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]: + logger.debug("Registering handler for CancelTaskRequest") + wrapper = create_call_wrapper(func, CancelTaskRequest) + + async def handler(req: CancelTaskRequest) -> ServerResult: + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[CancelTaskRequest] = handler + return func + + return decorator diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index f5a745db2f..6231aa8954 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -20,27 +20,27 @@ def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callab try: sig = inspect.signature(func) type_hints = get_type_hints(func) - except (ValueError, TypeError, NameError): + except (ValueError, TypeError, NameError): # pragma: no cover return lambda _: func() # Check for positional-only parameter typed as request_type for param_name, param in sig.parameters.items(): if param.kind == inspect.Parameter.POSITIONAL_ONLY: param_type = type_hints.get(param_name) - if param_type == request_type: + if param_type == request_type: # pragma: no branch # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: + if param.default is not inspect.Parameter.empty: # pragma: no cover return lambda _: func() # Found positional-only parameter with correct type and no default return lambda req: func(req) # Check for any positional/keyword parameter typed as request_type for param_name, param in sig.parameters.items(): - if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): + if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): # pragma: no branch param_type = type_hints.get(param_name) if param_type == request_type: # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: + if param.default is not inspect.Parameter.empty: # pragma: no cover return lambda _: func() # Found keyword parameter with correct type and no default diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 2fec3381bc..3fc2d497d1 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -67,6 +67,7 @@ async def main(): from __future__ import annotations as _annotations +import base64 import contextvars import json import logging @@ -82,14 +83,17 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types +from mcp.server.experimental.request_context import Experimental +from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.context import RequestContext -from mcp.shared.exceptions import McpError +from mcp.shared.exceptions import McpError, UrlElicitationRequiredError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.tool_name_validation import validate_and_warn_tool_name logger = logging.getLogger(__name__) @@ -154,6 +158,7 @@ def __init__( } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self._tool_cache: dict[str, types.Tool] = {} + self._experimental_handlers: ExperimentalHandlers | None = None logger.debug("Initializing server %r", name) def create_initialization_options( @@ -168,10 +173,10 @@ def pkg_version(package: str) -> str: from importlib.metadata import version return version(package) - except Exception: + except Exception: # pragma: no cover pass - return "unknown" + return "unknown" # pragma: no cover return InitializationOptions( server_name=self.name, @@ -212,14 +217,14 @@ def get_capabilities( tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed) # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: + if types.SetLevelRequest in self.request_handlers: # pragma: no cover logging_capability = types.LoggingCapability() # Set completions capabilities if handler exists if types.CompleteRequest in self.request_handlers: completions_capability = types.CompletionsCapability() - return types.ServerCapabilities( + capabilities = types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, tools=tools_capability, @@ -227,6 +232,9 @@ def get_capabilities( experimental=experimental_capabilities, completions=completions_capability, ) + if self._experimental_handlers: + self._experimental_handlers.update_capabilities(capabilities) + return capabilities @property def request_context( @@ -235,6 +243,18 @@ def request_context( """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() + @property + def experimental(self) -> ExperimentalHandlers: + """Experimental APIs for tasks and other features. + + WARNING: These APIs are experimental and may change without notice. + """ + + # We create this inline so we only add these capabilities _if_ they're actually used + if self._experimental_handlers is None: + self._experimental_handlers = ExperimentalHandlers(self, self.request_handlers, self.notification_handlers) + return self._experimental_handlers + def list_prompts(self): def decorator( func: Callable[[], Awaitable[list[types.Prompt]]] @@ -326,9 +346,7 @@ def create_content(data: str | bytes, mime_type: str | None): text=data, mimeType=mime_type or "text/plain", ) - case bytes() as data: - import base64 - + case bytes() as data: # pragma: no cover return types.BlobResourceContents( uri=req.params.uri, blob=base64.b64encode(data).decode(), @@ -336,7 +354,7 @@ def create_content(data: str | bytes, mime_type: str | None): ) match result: - case str() | bytes() as data: + case str() | bytes() as data: # pragma: no cover warnings.warn( "Returning str or bytes from read_resource is deprecated. " "Use Iterable[ReadResourceContents] instead.", @@ -353,10 +371,10 @@ def create_content(data: str | bytes, mime_type: str | None): contents=contents_list, ) ) - case _: + case _: # pragma: no cover raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - return types.ServerResult( + return types.ServerResult( # pragma: no cover types.ReadResourceResult( contents=[content], ) @@ -367,7 +385,7 @@ def create_content(data: str | bytes, mime_type: str | None): return decorator - def set_logging_level(self): + def set_logging_level(self): # pragma: no cover def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): logger.debug("Registering handler for SetLevelRequest") @@ -380,7 +398,7 @@ async def handler(req: types.SetLevelRequest): return decorator - def subscribe_resource(self): + def subscribe_resource(self): # pragma: no cover def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for SubscribeRequest") @@ -393,7 +411,7 @@ async def handler(req: types.SubscribeRequest): return decorator - def unsubscribe_resource(self): + def unsubscribe_resource(self): # pragma: no cover def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for UnsubscribeRequest") @@ -419,9 +437,10 @@ async def handler(req: types.ListToolsRequest): result = await wrapper(req) # Handle both old style (list[Tool]) and new style (ListToolsResult) - if isinstance(result, types.ListToolsResult): + if isinstance(result, types.ListToolsResult): # pragma: no cover # Refresh the tool cache with returned tools for tool in result.tools: + validate_and_warn_tool_name(tool.name) self._tool_cache[tool.name] = tool return types.ServerResult(result) else: @@ -429,6 +448,7 @@ async def handler(req: types.ListToolsRequest): # Clear and refresh the entire tool cache self._tool_cache.clear() for tool in result: + validate_and_warn_tool_name(tool.name) self._tool_cache[tool.name] = tool return types.ServerResult(types.ListToolsResult(tools=result)) @@ -480,7 +500,13 @@ def call_tool(self, *, validate_input: bool = True): def decorator( func: Callable[ ..., - Awaitable[UnstructuredContent | StructuredContent | CombinationContent], + Awaitable[ + UnstructuredContent + | StructuredContent + | CombinationContent + | types.CallToolResult + | types.CreateTaskResult + ], ], ): logger.debug("Registering handler for CallToolRequest") @@ -504,18 +530,23 @@ async def handler(req: types.CallToolRequest): # output normalization unstructured_content: UnstructuredContent maybe_structured_content: StructuredContent | None - if isinstance(results, tuple) and len(results) == 2: + if isinstance(results, types.CallToolResult): + return types.ServerResult(results) + elif isinstance(results, types.CreateTaskResult): + # Task-augmented execution returns task info instead of result + return types.ServerResult(results) + elif isinstance(results, tuple) and len(results) == 2: # tool returned both structured and unstructured content unstructured_content, maybe_structured_content = cast(CombinationContent, results) elif isinstance(results, dict): # tool returned structured content only maybe_structured_content = cast(StructuredContent, results) unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))] - elif hasattr(results, "__iter__"): + elif hasattr(results, "__iter__"): # pragma: no cover # tool returned unstructured content only unstructured_content = cast(UnstructuredContent, results) maybe_structured_content = None - else: + else: # pragma: no cover return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") # output validation @@ -538,6 +569,10 @@ async def handler(req: types.CallToolRequest): isError=False, ) ) + except UrlElicitationRequiredError: + # Re-raise UrlElicitationRequiredError so it can be properly handled + # by _handle_request, which converts it to an error response with code -32042 + raise except Exception as e: return self._make_error_result(str(e)) @@ -622,6 +657,12 @@ async def run( ) ) + # Configure task support for this session if enabled + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + if task_support is not None: + task_support.configure_session(session) + await stack.enter_async_context(task_support.run()) + async with anyio.create_task_group() as tg: async for message in session.incoming_messages: logger.debug("Received message: %s", message) @@ -642,67 +683,96 @@ async def _handle_message( raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - # TODO(Marcelo): We should be checking if message is Exception here. - match message: # type: ignore[reportMatchNotExhaustive] + match message: case RequestResponder(request=types.ClientRequest(root=req)) as responder: with responder: await self._handle_request(message, req, session, lifespan_context, raise_exceptions) case types.ClientNotification(root=notify): await self._handle_notification(notify) + case Exception(): # pragma: no cover + logger.error(f"Received exception from stream: {message}") + await session.send_log_message( + level="error", + data="Internal Server Error", + logger="mcp.server.exception_handler", + ) + if raise_exceptions: + raise message - for warning in w: + for warning in w: # pragma: no cover logger.info("Warning: %s: %s", warning.category.__name__, warning.message) async def _handle_request( self, message: RequestResponder[types.ClientRequest, types.ServerResult], - req: Any, + req: types.ClientRequestType, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): # type: ignore + + if handler := self.request_handlers.get(type(req)): logger.debug("Dispatching request of type %s", type(req).__name__) token = None try: - # Extract request context from message metadata + # Extract request context and close_sse_stream from message metadata request_data = None - if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): + close_sse_stream_cb = None + close_standalone_sse_stream_cb = None + if message.message_metadata is not None and isinstance( + message.message_metadata, ServerMessageMetadata + ): # pragma: no cover request_data = message.message_metadata.request_context + close_sse_stream_cb = message.message_metadata.close_sse_stream + close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream # Set our global state that can be retrieved via # app.get_request_context() + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + # Get task metadata from request params if present + task_metadata = None + if hasattr(req, "params") and req.params is not None: + task_metadata = getattr(req.params, "task", None) token = request_ctx.set( RequestContext( message.request_id, message.request_meta, session, lifespan_context, + Experimental( + task_metadata=task_metadata, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), request=request_data, + close_sse_stream=close_sse_stream_cb, + close_standalone_sse_stream=close_standalone_sse_stream_cb, ) ) response = await handler(req) - except McpError as err: + except McpError as err: # pragma: no cover response = err.error - except anyio.get_cancelled_exc_class(): + except anyio.get_cancelled_exc_class(): # pragma: no cover logger.info( "Request %s cancelled - duplicate response suppressed", message.request_id, ) return - except Exception as err: + except Exception as err: # pragma: no cover if raise_exceptions: raise err response = types.ErrorData(code=0, message=str(err), data=None) finally: # Reset the global state after we are done - if token is not None: + if token is not None: # pragma: no branch request_ctx.reset(token) await message.respond(response) - else: + else: # pragma: no cover await message.respond( types.ErrorData( code=types.METHOD_NOT_FOUND, @@ -718,7 +788,7 @@ async def _handle_notification(self, notify: Any): try: await handler(notify) - except Exception: + except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index d00277f11c..8f0baa3e9c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -38,7 +38,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: """ from enum import Enum -from typing import Any, TypeVar +from typing import Any, TypeVar, overload import anyio import anyio.lowlevel @@ -46,7 +46,11 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: from pydantic import AnyUrl import mcp.types as types +from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions +from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared.experimental.tasks.capabilities import check_tasks_capability +from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -79,6 +83,7 @@ class ServerSession( ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None + _experimental_features: ExperimentalServerSessionFeatures | None = None def __init__( self, @@ -100,17 +105,25 @@ def __init__( @property def client_params(self) -> types.InitializeRequestParams | None: - return self._client_params + return self._client_params # pragma: no cover - def check_client_capability(self, capability: types.ClientCapabilities) -> bool: + @property + def experimental(self) -> ExperimentalServerSessionFeatures: + """Experimental APIs for server→client task operations. + + WARNING: These APIs are experimental and may change without notice. + """ + if self._experimental_features is None: + self._experimental_features = ExperimentalServerSessionFeatures(self) + return self._experimental_features + + def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover """Check if the client supports a specific capability.""" if self._client_params is None: return False - # Get client capabilities from initialization params client_caps = self._client_params.capabilities - # Check each specified capability in the passed in capability object if capability.roots is not None: if client_caps.roots is None: return False @@ -120,19 +133,27 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if capability.sampling is not None: if client_caps.sampling is None: return False - - if capability.elicitation is not None: - if client_caps.elicitation is None: + if capability.sampling.context is not None and client_caps.sampling.context is None: return False + if capability.sampling.tools is not None and client_caps.sampling.tools is None: + return False + + if capability.elicitation is not None and client_caps.elicitation is None: + return False if capability.experimental is not None: if client_caps.experimental is None: return False - # Check each experimental capability for exp_key, exp_value in capability.experimental.items(): if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False + if capability.tasks is not None: + if client_caps.tasks is None: + return False + if not check_tasks_capability(capability.tasks, client_caps.tasks): + return False + return True async def _receive_loop(self) -> None: @@ -163,6 +184,7 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques ) ) ) + self._initialization_state = InitializationState.Initialized case types.PingRequest(): # Ping requests are allowed at any time pass @@ -177,7 +199,7 @@ async def _received_notification(self, notification: types.ClientNotification) - case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: - if self._initialization_state != InitializationState.Initialized: + if self._initialization_state != InitializationState.Initialized: # pragma: no cover raise RuntimeError("Received notification before initialization was complete") async def send_log_message( @@ -201,7 +223,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: AnyUrl) -> None: + async def send_resource_updated(self, uri: AnyUrl) -> None: # pragma: no cover """Send a resource updated notification.""" await self.send_notification( types.ServerNotification( @@ -211,6 +233,7 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: ) ) + @overload async def create_message( self, messages: list[types.SamplingMessage], @@ -222,28 +245,106 @@ async def create_message( stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, model_preferences: types.ModelPreferences | None = None, + tools: None = None, + tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, ) -> types.CreateMessageResult: - """Send a sampling/create_message request.""" + """Overload: Without tools, returns single content.""" + ... + + @overload + async def create_message( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool], + tool_choice: types.ToolChoice | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResultWithTools: + """Overload: With tools, returns array-capable content.""" + ... + + async def create_message( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool] | None = None, + tool_choice: types.ToolChoice | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: + """Send a sampling/create_message request. + + Args: + messages: The conversation messages to send. + max_tokens: Maximum number of tokens to generate. + system_prompt: Optional system prompt. + include_context: Optional context inclusion setting. + Should only be set to "thisServer" or "allServers" + if the client has sampling.context capability. + temperature: Optional sampling temperature. + stop_sequences: Optional stop sequences. + metadata: Optional metadata to pass through to the LLM provider. + model_preferences: Optional model selection preferences. + tools: Optional list of tools the LLM can use during sampling. + Requires client to have sampling.tools capability. + tool_choice: Optional control over tool usage behavior. + Requires client to have sampling.tools capability. + related_request_id: Optional ID of a related request. + + Returns: + The sampling result from the client. + + Raises: + McpError: If tools are provided but client doesn't support them. + ValueError: If tool_use or tool_result message structure is invalid. + """ + client_caps = self._client_params.capabilities if self._client_params else None + validate_sampling_tools(client_caps, tools, tool_choice) + validate_tool_use_result_messages(messages) + + request = types.ServerRequest( + types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=messages, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + maxTokens=max_tokens, + stopSequences=stop_sequences, + metadata=metadata, + modelPreferences=model_preferences, + tools=tools, + toolChoice=tool_choice, + ), + ) + ) + metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) + + # Use different result types based on whether tools are provided + if tools is not None: + return await self.send_request( + request=request, + result_type=types.CreateMessageResultWithTools, + metadata=metadata_obj, + ) return await self.send_request( - request=types.ServerRequest( - types.CreateMessageRequest( - params=types.CreateMessageRequestParams( - messages=messages, - systemPrompt=system_prompt, - includeContext=include_context, - temperature=temperature, - maxTokens=max_tokens, - stopSequences=stop_sequences, - metadata=metadata, - modelPreferences=model_preferences, - ), - ) - ), + request=request, result_type=types.CreateMessageResult, - metadata=ServerMessageMetadata( - related_request_id=related_request_id, - ), + metadata=metadata_obj, ) async def list_roots(self) -> types.ListRootsResult: @@ -259,19 +360,42 @@ async def elicit( requestedSchema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, ) -> types.ElicitResult: - """Send an elicitation/create request. + """Send a form mode elicitation/create request. Args: message: The message to present to the user requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation Returns: The client's response + + Note: + This method is deprecated in favor of elicit_form(). It remains for + backward compatibility but new code should use elicit_form(). + """ + return await self.elicit_form(message, requestedSchema, related_request_id) + + async def elicit_form( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a form mode elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response with form data """ return await self.send_request( types.ServerRequest( types.ElicitRequest( - params=types.ElicitRequestParams( + params=types.ElicitRequestFormParams( message=message, requestedSchema=requestedSchema, ), @@ -281,7 +405,42 @@ async def elicit( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) - async def send_ping(self) -> types.EmptyResult: + async def elicit_url( + self, + message: str, + url: str, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send a URL mode elicitation/create request. + + This directs the user to an external URL for out-of-band interactions + like OAuth flows, credential collection, or payment processing. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_request_id: Optional ID of the request that triggered this elicitation + + Returns: + The client's response indicating acceptance, decline, or cancellation + """ + return await self.send_request( + types.ServerRequest( + types.ElicitRequest( + params=types.ElicitRequestURLParams( + message=message, + url=url, + elicitationId=elicitation_id, + ), + ) + ), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) + + async def send_ping(self) -> types.EmptyResult: # pragma: no cover """Send a ping request.""" return await self.send_request( types.ServerRequest(types.PingRequest()), @@ -311,18 +470,217 @@ async def send_progress_notification( related_request_id, ) - async def send_resource_list_changed(self) -> None: + async def send_resource_list_changed(self) -> None: # pragma: no cover """Send a resource list changed notification.""" await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification())) - async def send_tool_list_changed(self) -> None: + async def send_tool_list_changed(self) -> None: # pragma: no cover """Send a tool list changed notification.""" await self.send_notification(types.ServerNotification(types.ToolListChangedNotification())) - async def send_prompt_list_changed(self) -> None: + async def send_prompt_list_changed(self) -> None: # pragma: no cover """Send a prompt list changed notification.""" await self.send_notification(types.ServerNotification(types.PromptListChangedNotification())) + async def send_elicit_complete( + self, + elicitation_id: str, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send an elicitation completion notification. + + This should be sent when a URL mode elicitation has been completed + out-of-band to inform the client that it may retry any requests + that were waiting for this elicitation. + + Args: + elicitation_id: The unique identifier of the completed elicitation + related_request_id: Optional ID of the request that triggered this + """ + await self.send_notification( + types.ServerNotification( + types.ElicitCompleteNotification( + params=types.ElicitCompleteNotificationParams(elicitationId=elicitation_id) + ) + ), + related_request_id, + ) + + def _build_elicit_form_request( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_task_id: str | None = None, + task: types.TaskMetadata | None = None, + ) -> types.JSONRPCRequest: + """Build a form mode elicitation request without sending it. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata + task: If provided, makes this a task-augmented request + + Returns: + A JSONRPCRequest ready to be sent or queued + """ + params = types.ElicitRequestFormParams( + message=message, + requestedSchema=requestedSchema, + task=task, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata if associated with a parent task + if related_task_id is not None: + # Defensive: model_dump() never includes _meta, but guard against future changes + if "_meta" not in params_data: # pragma: no cover + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( + taskId=related_task_id + ).model_dump(by_alias=True) + + request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id + if related_task_id is None: + self._request_id += 1 + + return types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="elicitation/create", + params=params_data, + ) + + def _build_elicit_url_request( + self, + message: str, + url: str, + elicitation_id: str, + related_task_id: str | None = None, + ) -> types.JSONRPCRequest: + """Build a URL mode elicitation request without sending it. + + Args: + message: Human-readable explanation of why the interaction is needed + url: The URL the user should navigate to + elicitation_id: Unique identifier for tracking this elicitation + related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata + + Returns: + A JSONRPCRequest ready to be sent or queued + """ + params = types.ElicitRequestURLParams( + message=message, + url=url, + elicitationId=elicitation_id, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata if associated with a parent task + if related_task_id is not None: + # Defensive: model_dump() never includes _meta, but guard against future changes + if "_meta" not in params_data: # pragma: no cover + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( + taskId=related_task_id + ).model_dump(by_alias=True) + + request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id + if related_task_id is None: + self._request_id += 1 + + return types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="elicitation/create", + params=params_data, + ) + + def _build_create_message_request( + self, + messages: list[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + tools: list[types.Tool] | None = None, + tool_choice: types.ToolChoice | None = None, + related_task_id: str | None = None, + task: types.TaskMetadata | None = None, + ) -> types.JSONRPCRequest: + """Build a sampling/createMessage request without sending it. + + Args: + messages: The conversation messages to send + max_tokens: Maximum number of tokens to generate + system_prompt: Optional system prompt + include_context: Optional context inclusion setting + temperature: Optional sampling temperature + stop_sequences: Optional stop sequences + metadata: Optional metadata to pass through to the LLM provider + model_preferences: Optional model selection preferences + tools: Optional list of tools the LLM can use during sampling + tool_choice: Optional control over tool usage behavior + related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata + task: If provided, makes this a task-augmented request + + Returns: + A JSONRPCRequest ready to be sent or queued + """ + params = types.CreateMessageRequestParams( + messages=messages, + systemPrompt=system_prompt, + includeContext=include_context, + temperature=temperature, + maxTokens=max_tokens, + stopSequences=stop_sequences, + metadata=metadata, + modelPreferences=model_preferences, + tools=tools, + toolChoice=tool_choice, + task=task, + ) + params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True) + + # Add related-task metadata if associated with a parent task + if related_task_id is not None: + # Defensive: model_dump() never includes _meta, but guard against future changes + if "_meta" not in params_data: # pragma: no cover + params_data["_meta"] = {} + params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata( + taskId=related_task_id + ).model_dump(by_alias=True) + + request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id + if related_task_id is None: + self._request_id += 1 + + return types.JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + method="sampling/createMessage", + params=params_data, + ) + + async def send_message(self, message: SessionMessage) -> None: + """Send a raw session message. + + This is primarily used by TaskResultHandler to deliver queued messages + (elicitation/sampling requests) to the client during task execution. + + WARNING: This is a low-level experimental method that may change without + notice. Prefer using higher-level methods like send_notification() or + send_request() for normal operations. + + Args: + message: The session message to send + """ + await self._write_stream.send(message) + async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index b7ff332803..19af93fd16 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -119,7 +119,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") @@ -198,7 +198,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover logger.debug("Handling POST message") request = Request(scope, receive) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index d1618a3712..bcb9247abb 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -63,13 +63,13 @@ async def stdin_reader(): async for line in stdin: try: message = types.JSONRPCMessage.model_validate_json(line) - except Exception as exc: + except Exception as exc: # pragma: no cover await read_stream_writer.send(exc) continue session_message = SessionMessage(message) await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() async def stdout_writer(): @@ -79,7 +79,7 @@ async def stdout_writer(): json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) await stdout.write(json + "\n") await stdout.flush() - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() async with anyio.create_task_group() as tg: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index b45d742b00..2613b530c4 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -15,6 +15,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from http import HTTPStatus +from typing import Any import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -87,18 +88,18 @@ class EventStore(ABC): """ @abstractmethod - async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: """ Stores an event for later retrieval. Args: stream_id: ID of the stream the event belongs to - message: The JSON-RPC message to store + message: The JSON-RPC message to store, or None for priming events Returns: The generated event ID for the stored event """ - pass + pass # pragma: no cover @abstractmethod async def replay_events_after( @@ -116,7 +117,7 @@ async def replay_events_after( Returns: The stream ID of the replayed events """ - pass + pass # pragma: no cover class StreamableHTTPServerTransport: @@ -140,6 +141,7 @@ def __init__( is_json_response_enabled: bool = False, event_store: EventStore | None = None, security_settings: TransportSecuritySettings | None = None, + retry_interval: int | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -153,6 +155,10 @@ def __init__( resumability will be enabled, allowing clients to reconnect and resume messages. security_settings: Optional security settings for DNS rebinding protection. + retry_interval: Retry interval in milliseconds to suggest to clients in SSE + retry field. When set, the server will send a retry field in + SSE priming events to control client reconnection timing for + polling behavior. Only used when event_store is provided. Raises: ValueError: If the session ID contains invalid characters. @@ -164,6 +170,7 @@ def __init__( self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store self._security = TransportSecurityMiddleware(security_settings) + self._retry_interval = retry_interval self._request_streams: dict[ RequestId, tuple[ @@ -171,6 +178,7 @@ def __init__( MemoryObjectReceiveStream[EventMessage], ], ] = {} + self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {} self._terminated = False @property @@ -178,6 +186,111 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated + def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + """Close SSE connection for a specific request without terminating the stream. + + This method closes the HTTP connection for the specified request, triggering + client reconnection. Events continue to be stored in the event store and will + be replayed when the client reconnects with Last-Event-ID. + + Use this to implement polling behavior during long-running operations - + client will reconnect after the retry interval specified in the priming event. + + Args: + request_id: The request ID whose SSE stream should be closed. + + Note: + This is a no-op if there is no active stream for the request ID. + Requires event_store to be configured for events to be stored during + the disconnect. + """ + writer = self._sse_stream_writers.pop(request_id, None) + if writer: + writer.close() + + # Also close and remove request streams + if request_id in self._request_streams: + send_stream, receive_stream = self._request_streams.pop(request_id) + send_stream.close() + receive_stream.close() + + def close_standalone_sse_stream(self) -> None: # pragma: no cover + """Close the standalone GET SSE stream, triggering client reconnection. + + This method closes the HTTP connection for the standalone GET stream used + for unsolicited server-to-client notifications. The client SHOULD reconnect + with Last-Event-ID to resume receiving notifications. + + Use this to implement polling behavior for the notification stream - + client will reconnect after the retry interval specified in the priming event. + + Note: + This is a no-op if there is no active standalone SSE stream. + Requires event_store to be configured for events to be stored during + the disconnect. + Currently, client reconnection for standalone GET streams is NOT + implemented - this is a known gap (see test_standalone_get_stream_reconnection). + """ + self.close_sse_stream(GET_STREAM_KEY) + + def _create_session_message( # pragma: no cover + self, + message: JSONRPCMessage, + request: Request, + request_id: RequestId, + protocol_version: str, + ) -> SessionMessage: + """Create a session message with metadata including close_sse_stream callback. + + The close_sse_stream callbacks are only provided when the client supports + resumability (protocol version >= 2025-11-25). Old clients can't resume if + the stream is closed early because they didn't receive a priming event. + """ + # Only provide close callbacks when client supports resumability + if self._event_store and protocol_version >= "2025-11-25": + + async def close_stream_callback() -> None: + self.close_sse_stream(request_id) + + async def close_standalone_stream_callback() -> None: + self.close_standalone_sse_stream() + + metadata = ServerMessageMetadata( + request_context=request, + close_sse_stream=close_stream_callback, + close_standalone_sse_stream=close_standalone_stream_callback, + ) + else: + metadata = ServerMessageMetadata(request_context=request) + + return SessionMessage(message, metadata=metadata) + + async def _maybe_send_priming_event( + self, + request_id: RequestId, + sse_stream_writer: MemoryObjectSendStream[dict[str, Any]], + protocol_version: str, + ) -> None: + """Send priming event for SSE resumability if event_store is configured. + + Only sends priming events to clients with protocol version >= 2025-11-25, + which includes the fix for handling empty SSE data. Older clients would + crash trying to parse empty data as JSON. + """ + if not self._event_store: + return + # Priming events have empty data which older clients cannot handle. + if protocol_version < "2025-11-25": + return + priming_event_id = await self._event_store.store_event( + str(request_id), # Convert RequestId to StreamId (str) + None, # Priming event has no payload + ) + priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""} + if self._retry_interval is not None: + priming_event["retry"] = self._retry_interval + await sse_stream_writer.send(priming_event) + def _create_error_response( self, error_message: str, @@ -187,7 +300,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: + if headers: # pragma: no cover response_headers.update(headers) if self.mcp_session_id: @@ -209,7 +322,7 @@ def _create_error_response( headers=response_headers, ) - def _create_json_response( + def _create_json_response( # pragma: no cover self, response_message: JSONRPCMessage | None, status_code: HTTPStatus = HTTPStatus.OK, @@ -229,11 +342,11 @@ def _create_json_response( headers=response_headers, ) - def _get_session_id(self, request: Request) -> str | None: + def _get_session_id(self, request: Request) -> str | None: # pragma: no cover """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) - def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", @@ -246,7 +359,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: return event_data - async def _clean_up_memory_streams(self, request_id: RequestId) -> None: + async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # pragma: no cover """Clean up memory streams for a given request ID.""" if request_id in self._request_streams: try: @@ -267,11 +380,11 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No # Validate request headers for DNS rebinding protection is_post = request.method == "POST" error_response = await self._security.validate_request(request, is_post=is_post) - if error_response: + if error_response: # pragma: no cover await error_response(scope, receive, send) return - if self._terminated: + if self._terminated: # pragma: no cover # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -282,11 +395,11 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No if request.method == "POST": await self._handle_post_request(scope, request, receive, send) - elif request.method == "GET": + elif request.method == "GET": # pragma: no cover await self._handle_get_request(request, send) - elif request.method == "DELETE": + elif request.method == "DELETE": # pragma: no cover await self._handle_delete_request(request, send) - else: + else: # pragma: no cover await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -306,24 +419,40 @@ def _check_content_type(self, request: Request) -> bool: return any(part == CONTENT_TYPE_JSON for part in content_type_parts) + async def _validate_accept_header(self, request: Request, scope: Scope, send: Send) -> bool: # pragma: no cover + """Validate Accept header based on response mode. Returns True if valid.""" + has_json, has_sse = self._check_accept_headers(request) + if self.is_json_response_enabled: + # For JSON-only responses, only require application/json + if not has_json: + response = self._create_error_response( + "Not Acceptable: Client must accept application/json", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, request.receive, send) + return False + # For SSE responses, require both content types + elif not (has_json and has_sse): + response = self._create_error_response( + "Not Acceptable: Client must accept both application/json and text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, request.receive, send) + return False + return True + async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer - if writer is None: + if writer is None: # pragma: no cover raise ValueError("No read stream writer available. Ensure connect() is called first.") try: - # Check Accept headers - has_json, has_sse = self._check_accept_headers(request) - if not (has_json and has_sse): - response = self._create_error_response( - ("Not Acceptable: Client must accept both application/json and text/event-stream"), - HTTPStatus.NOT_ACCEPTABLE, - ) - await response(scope, receive, send) + # Validate Accept header + if not await self._validate_accept_header(request, scope, send): return # Validate Content-Type - if not self._check_content_type(request): + if not self._check_content_type(request): # pragma: no cover response = self._create_error_response( "Unsupported Media Type: Content-Type must be application/json", HTTPStatus.UNSUPPORTED_MEDIA_TYPE, @@ -341,9 +470,9 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) return - try: + try: # pragma: no cover message = JSONRPCMessage.model_validate(raw_message) - except ValidationError as e: + except ValidationError as e: # pragma: no cover response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -353,9 +482,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return # Check if this is an initialization request - is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" + is_initialization_request = ( + isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" + ) # pragma: no cover - if is_initialization_request: + if is_initialization_request: # pragma: no cover # Check if the server already has an established session if self.mcp_session_id: # Check if request has a session ID @@ -369,11 +500,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): + elif not await self._validate_request_headers(request, send): # pragma: no cover return # For notifications and responses only, return 202 Accepted - if not isinstance(message.root, JSONRPCRequest): + if not isinstance(message.root, JSONRPCRequest): # pragma: no cover # Create response object and send it response = self._create_json_response( None, @@ -388,13 +519,22 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return + # Extract protocol version for priming event decision. + # For initialize requests, get from request params. + # For other requests, get from header (already validated). + protocol_version = ( + str(message.root.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION)) + if is_initialization_request and message.root.params + else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) + ) + # Extract the request ID outside the try block for proper scope - request_id = str(message.root.id) + request_id = str(message.root.id) # pragma: no cover # Register this stream for the request ID - self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) - request_stream_reader = self._request_streams[request_id][1] + self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover + request_stream_reader = self._request_streams[request_id][1] # pragma: no cover - if self.is_json_response_enabled: + if self.is_json_response_enabled: # pragma: no cover # Process the message metadata = ServerMessageMetadata(request_context=request) session_message = SessionMessage(message, metadata=metadata) @@ -437,14 +577,20 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) finally: await self._clean_up_memory_streams(request_id) - else: + else: # pragma: no cover # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + # Store writer reference so close_sse_stream() can close it + self._sse_stream_writers[request_id] = sse_stream_writer + async def sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: + # Send priming event for SSE resumability + await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version) + # Process messages from the request-specific stream async for event_message in request_stream_reader: # Build the event data @@ -457,10 +603,14 @@ async def sse_writer(): JSONRPCResponse | JSONRPCError, ): break + except anyio.ClosedResourceError: + # Expected when close_sse_stream() is called + logger.debug("SSE stream closed by close_sse_stream()") except Exception: logger.exception("Error in SSE writer") finally: logger.debug("Closing SSE writer") + self._sse_stream_writers.pop(request_id, None) await self._clean_up_memory_streams(request_id) # Create and start EventSourceResponse @@ -484,8 +634,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - metadata = ServerMessageMetadata(request_context=request) - session_message = SessionMessage(message, metadata=metadata) + session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) except Exception: logger.exception("SSE response error") @@ -493,7 +642,7 @@ async def sse_writer(): await sse_stream_reader.aclose() await self._clean_up_memory_streams(request_id) - except Exception as err: + except Exception as err: # pragma: no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -505,7 +654,7 @@ async def sse_writer(): await writer.send(Exception(err)) return - async def _handle_get_request(self, request: Request, send: Send) -> None: + async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover """ Handle GET request to establish SSE. @@ -597,7 +746,7 @@ async def standalone_sse_writer(): await sse_stream_reader.aclose() await self._clean_up_memory_streams(GET_STREAM_KEY) - async def _handle_delete_request(self, request: Request, send: Send) -> None: + async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: no cover """Handle DELETE requests for explicit session termination.""" # Validate session ID if not self.mcp_session_id: @@ -633,25 +782,25 @@ async def terminate(self) -> None: request_stream_keys = list(self._request_streams.keys()) # Close all request streams asynchronously - for key in request_stream_keys: + for key in request_stream_keys: # pragma: no cover await self._clean_up_memory_streams(key) # Clear the request streams dictionary immediately self._request_streams.clear() try: - if self._read_stream_writer is not None: + if self._read_stream_writer is not None: # pragma: no branch await self._read_stream_writer.aclose() - if self._read_stream is not None: + if self._read_stream is not None: # pragma: no branch await self._read_stream.aclose() - if self._write_stream_reader is not None: + if self._write_stream_reader is not None: # pragma: no branch await self._write_stream_reader.aclose() - if self._write_stream is not None: + if self._write_stream is not None: # pragma: no branch await self._write_stream.aclose() - except Exception as e: + except Exception as e: # pragma: no cover # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, @@ -667,14 +816,14 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - async def _validate_request_headers(self, request: Request, send: Send) -> bool: + async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: no cover if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): return False return True - async def _validate_session(self, request: Request, send: Send) -> bool: + async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: no cover """Validate the session ID in the request.""" if not self.mcp_session_id: # If we're not using session IDs, return True @@ -703,7 +852,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: return True - async def _validate_protocol_version(self, request: Request, send: Send) -> bool: + async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: no cover """Validate the protocol version header in the request.""" # Get the protocol version from the request headers protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) @@ -725,7 +874,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover """ Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. @@ -744,6 +893,9 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) if self.mcp_session_id: headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + # Get protocol version from header (already validated in _validate_protocol_version) + replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) + # Create SSE stream for replay sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) @@ -760,6 +912,13 @@ async def send_event(event_message: EventMessage) -> None: # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: + # Register SSE writer so close_sse_stream() can close it + self._sse_stream_writers[stream_id] = sse_stream_writer + + # Send priming event for this new connection + await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version) + + # Create new request streams for this connection self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) msg_reader = self._request_streams[stream_id][1] @@ -769,6 +928,9 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) + except anyio.ClosedResourceError: + # Expected when close_sse_stream() is called + logger.debug("Replay SSE stream closed by close_sse_stream()") except Exception: logger.exception("Error in replay sender") @@ -826,7 +988,7 @@ async def connect( # Start a task group for message routing async with anyio.create_task_group() as tg: # Create a message router that distributes messages to request streams - async def message_router(): + async def message_router(): # pragma: no cover try: async for session_message in write_stream_reader: # Determine which request stream(s) should receive this message @@ -870,11 +1032,16 @@ async def message_router(): # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) else: - logging.debug( + logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client might reconnect and replay.""" ) + except anyio.ClosedResourceError: + if self._terminated: + logger.debug("Read stream closed by client") + else: + logger.exception("Unexpected closure of read stream in message router") except Exception: logger.exception("Error in message router") @@ -885,7 +1052,7 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream_id in list(self._request_streams.keys()): + for stream_id in list(self._request_streams.keys()): # pragma: no cover await self._clean_up_memory_streams(stream_id) self._request_streams.clear() @@ -895,6 +1062,6 @@ async def message_router(): await read_stream.aclose() await write_stream_reader.aclose() await write_stream.aclose() - except Exception as e: + except Exception as e: # pragma: no cover # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 53d542d21b..50d2aefa29 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -51,6 +51,9 @@ class StreamableHTTPSessionManager: json_response: Whether to use JSON responses instead of SSE streams stateless: If True, creates a completely fresh transport for each request with no session tracking or state persistence between requests. + security_settings: Optional transport security settings. + retry_interval: Retry interval in milliseconds to suggest to clients in SSE + retry field. Used for SSE polling behavior. """ def __init__( @@ -60,12 +63,14 @@ def __init__( json_response: bool = False, stateless: bool = False, security_settings: TransportSecuritySettings | None = None, + retry_interval: int | None = None, ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless self.security_settings = security_settings + self.retry_interval = retry_interval # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -178,7 +183,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: + except Exception: # pragma: no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking @@ -210,7 +215,7 @@ async def _handle_stateful_request( request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) # Existing session case - if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: + if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") await transport.handle_request(scope, receive, send) @@ -226,6 +231,7 @@ async def _handle_stateful_request( is_json_response_enabled=self.json_response, event_store=self.event_store, # May be None (no resumability) security_settings=self.security_settings, + retry_interval=self.retry_interval, ) assert http_transport.mcp_session_id is not None @@ -251,7 +257,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE ) finally: # Only remove from instances if not terminated - if ( + if ( # pragma: no branch http_transport.mcp_session_id and http_transport.mcp_session_id in self._server_instances and not http_transport.is_terminated @@ -270,7 +276,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) - else: + else: # pragma: no cover # Invalid session ID response = Response( "Bad Request: No valid session ID provided", diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py deleted file mode 100644 index a74751312c..0000000000 --- a/src/mcp/server/streaming_asgi_transport.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -A modified version of httpx.ASGITransport that supports streaming responses. - -This transport runs the ASGI app as a separate anyio task, allowing it to -handle streaming responses like SSE where the app doesn't terminate until -the connection is closed. - -This is only intended for writing tests for the SSE transport. -""" - -import typing -from typing import Any, cast - -import anyio -import anyio.abc -import anyio.streams.memory -from httpx._models import Request, Response -from httpx._transports.base import AsyncBaseTransport -from httpx._types import AsyncByteStream -from starlette.types import ASGIApp, Receive, Scope, Send - - -class StreamingASGITransport(AsyncBaseTransport): - """ - A custom AsyncTransport that handles sending requests directly to an ASGI app - and supports streaming responses like SSE. - - Unlike the standard ASGITransport, this transport runs the ASGI app in a - separate anyio task, allowing it to handle responses from apps that don't - terminate immediately (like SSE endpoints). - - Arguments: - - * `app` - The ASGI application. - * `raise_app_exceptions` - Boolean indicating if exceptions in the application - should be raised. Default to `True`. Can be set to `False` for use cases - such as testing the content of a client 500 response. - * `root_path` - The root path on which the ASGI application should be mounted. - * `client` - A two-tuple indicating the client IP and port of incoming requests. - * `response_timeout` - Timeout in seconds to wait for the initial response. - Default is 10 seconds. - - TODO: https://github.com/encode/httpx/pull/3059 is adding something similar to - upstream httpx. When that merges, we should delete this & switch back to the - upstream implementation. - """ - - def __init__( - self, - app: ASGIApp, - task_group: anyio.abc.TaskGroup, - raise_app_exceptions: bool = True, - root_path: str = "", - client: tuple[str, int] = ("127.0.0.1", 123), - ) -> None: - self.app = app - self.raise_app_exceptions = raise_app_exceptions - self.root_path = root_path - self.client = client - self.task_group = task_group - - async def handle_async_request( - self, - request: Request, - ) -> Response: - assert isinstance(request.stream, AsyncByteStream) - - # ASGI scope. - scope = { - "type": "http", - "asgi": {"version": "3.0"}, - "http_version": "1.1", - "method": request.method, - "headers": [(k.lower(), v) for (k, v) in request.headers.raw], - "scheme": request.url.scheme, - "path": request.url.path, - "raw_path": request.url.raw_path.split(b"?")[0], - "query_string": request.url.query, - "server": (request.url.host, request.url.port), - "client": self.client, - "root_path": self.root_path, - } - - # Request body - request_body_chunks = request.stream.__aiter__() - request_complete = False - - # Response state - status_code = 499 - response_headers = None - response_started = False - response_complete = anyio.Event() - initial_response_ready = anyio.Event() - - # Synchronization for streaming response - asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100) - content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) - - # ASGI callables. - async def receive() -> dict[str, Any]: - nonlocal request_complete - - if request_complete: - await response_complete.wait() - return {"type": "http.disconnect"} - - try: - body = await request_body_chunks.__anext__() - except StopAsyncIteration: - request_complete = True - return {"type": "http.request", "body": b"", "more_body": False} - return {"type": "http.request", "body": body, "more_body": True} - - async def send(message: dict[str, Any]) -> None: - nonlocal status_code, response_headers, response_started - - await asgi_send_channel.send(message) - - # Start the ASGI application in a separate task - async def run_app() -> None: - try: - # Cast the receive and send functions to the ASGI types - await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send)) - except Exception: - if self.raise_app_exceptions: - raise - - if not response_started: - await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []}) - - await asgi_send_channel.send({"type": "http.response.body", "body": b"", "more_body": False}) - finally: - await asgi_send_channel.aclose() - - # Process messages from the ASGI app - async def process_messages() -> None: - nonlocal status_code, response_headers, response_started - - try: - async with asgi_receive_channel: - async for message in asgi_receive_channel: - if message["type"] == "http.response.start": - assert not response_started - status_code = message["status"] - response_headers = message.get("headers", []) - response_started = True - - # As soon as we have headers, we can return a response - initial_response_ready.set() - - elif message["type"] == "http.response.body": - body = message.get("body", b"") - more_body = message.get("more_body", False) - - if body and request.method != "HEAD": - await content_send_channel.send(body) - - if not more_body: - response_complete.set() - await content_send_channel.aclose() - break - finally: - # Ensure events are set even if there's an error - initial_response_ready.set() - response_complete.set() - await content_send_channel.aclose() - - # Create tasks for running the app and processing messages - self.task_group.start_soon(run_app) - self.task_group.start_soon(process_messages) - - # Wait for the initial response or timeout - await initial_response_ready.wait() - - # Create a streaming response - return Response( - status_code, - headers=response_headers, - stream=StreamingASGIResponseStream(content_receive_channel), - ) - - -class StreamingASGIResponseStream(AsyncByteStream): - """ - A modified ASGIResponseStream that supports streaming responses. - - This class extends the standard ASGIResponseStream to handle cases where - the response body continues to be generated after the initial response - is returned. - """ - - def __init__( - self, - receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], - ) -> None: - self.receive_channel = receive_channel - - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - try: - async for chunk in self.receive_channel: - yield chunk - finally: - await self.receive_channel.aclose() diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index de4542af65..ee1e4505a7 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -42,7 +42,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: + def _validate_host(self, host: str | None) -> bool: # pragma: no cover """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -64,7 +64,7 @@ def _validate_host(self, host: str | None) -> bool: logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: + def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: @@ -86,7 +86,7 @@ def _validate_origin(self, origin: str | None) -> bool: logger.warning(f"Invalid Origin header: {origin}") return False - def _validate_content_type(self, content_type: str | None) -> bool: + def _validate_content_type(self, content_type: str | None) -> bool: # pragma: no cover """Validate the Content-Type header for POST requests.""" if not content_type: logger.warning("Missing Content-Type header in POST request") @@ -105,23 +105,23 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: + if is_post: # pragma: no branch content_type = request.headers.get("content-type") - if not self._validate_content_type(content_type): + if not self._validate_content_type(content_type): # pragma: no cover return Response("Invalid Content-Type header", status_code=400) # Skip remaining validation if DNS rebinding protection is disabled if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header - host = request.headers.get("host") - if not self._validate_host(host): - return Response("Invalid Host header", status_code=421) + # Validate Host header # pragma: no cover + host = request.headers.get("host") # pragma: no cover + if not self._validate_host(host): # pragma: no cover + return Response("Invalid Host header", status_code=421) # pragma: no cover - # Validate Origin header - origin = request.headers.get("origin") - if not self._validate_origin(origin): - return Response("Invalid Origin header", status_code=403) + # Validate Origin header # pragma: no cover + origin = request.headers.get("origin") # pragma: no cover + if not self._validate_origin(origin): # pragma: no cover + return Response("Invalid Origin header", status_code=403) # pragma: no cover - return None + return None # pragma: no cover diff --git a/src/mcp/server/validation.py b/src/mcp/server/validation.py new file mode 100644 index 0000000000..2ccd7056bd --- /dev/null +++ b/src/mcp/server/validation.py @@ -0,0 +1,104 @@ +""" +Shared validation functions for server requests. + +This module provides validation logic for sampling and elicitation requests +that is shared across normal and task-augmented code paths. +""" + +from mcp.shared.exceptions import McpError +from mcp.types import ( + INVALID_PARAMS, + ClientCapabilities, + ErrorData, + SamplingMessage, + Tool, + ToolChoice, +) + + +def check_sampling_tools_capability(client_caps: ClientCapabilities | None) -> bool: + """ + Check if the client supports sampling tools capability. + + Args: + client_caps: The client's declared capabilities + + Returns: + True if client supports sampling.tools, False otherwise + """ + if client_caps is None: + return False + if client_caps.sampling is None: + return False + if client_caps.sampling.tools is None: + return False + return True + + +def validate_sampling_tools( + client_caps: ClientCapabilities | None, + tools: list[Tool] | None, + tool_choice: ToolChoice | None, +) -> None: + """ + Validate that the client supports sampling tools if tools are being used. + + Args: + client_caps: The client's declared capabilities + tools: The tools list, if provided + tool_choice: The tool choice setting, if provided + + Raises: + McpError: If tools/tool_choice are provided but client doesn't support them + """ + if tools is not None or tool_choice is not None: + if not check_sampling_tools_capability(client_caps): + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message="Client does not support sampling tools capability", + ) + ) + + +def validate_tool_use_result_messages(messages: list[SamplingMessage]) -> None: + """ + Validate tool_use/tool_result message structure per SEP-1577. + + This validation ensures: + 1. Messages with tool_result content contain ONLY tool_result content + 2. tool_result messages are preceded by a message with tool_use + 3. tool_result IDs match the tool_use IDs from the previous message + + See: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577 + + Args: + messages: The list of sampling messages to validate + + Raises: + ValueError: If the message structure is invalid + """ + if not messages: + return + + last_content = messages[-1].content_as_list + has_tool_results = any(c.type == "tool_result" for c in last_content) + + previous_content = messages[-2].content_as_list if len(messages) >= 2 else None + has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content) + + if has_tool_results: + # Per spec: "SamplingMessage with tool result content blocks + # MUST NOT contain other content types." + if any(c.type != "tool_result" for c in last_content): + raise ValueError("The last message must contain only tool_result content if any is present") + if previous_content is None: + raise ValueError("tool_result requires a previous message containing tool_use") + if not has_previous_tool_use: + raise ValueError("tool_result blocks do not match any tool_use in the previous message") + + if has_previous_tool_use and previous_content: + tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"} + tool_result_ids = {c.toolUseId for c in last_content if c.type == "tool_result"} + if tool_use_ids != tool_result_ids: + raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match") diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 7c0d8789cb..5d5efd16e9 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -@asynccontextmanager +@asynccontextmanager # pragma: no cover async def websocket_server(scope: Scope, receive: Receive, send: Send): """ WebSocket server transport for MCP. This is an ASGI application, suitable to be diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index e0611ce73d..945ef80955 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -4,11 +4,15 @@ import httpx -__all__ = ["create_mcp_http_client"] +__all__ = ["create_mcp_http_client", "MCP_DEFAULT_TIMEOUT", "MCP_DEFAULT_SSE_READ_TIMEOUT"] +# Default MCP timeout configuration +MCP_DEFAULT_TIMEOUT = 30.0 # General operations (seconds) +MCP_DEFAULT_SSE_READ_TIMEOUT = 300.0 # SSE streams - 5 minutes (seconds) -class McpHttpClientFactory(Protocol): - def __call__( + +class McpHttpClientFactory(Protocol): # pragma: no branch + def __call__( # pragma: no branch self, headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, @@ -68,7 +72,7 @@ def create_mcp_http_client( # Handle timeout if timeout is None: - kwargs["timeout"] = httpx.Timeout(30.0) + kwargs["timeout"] = httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT) else: kwargs["timeout"] = timeout @@ -77,7 +81,7 @@ def create_mcp_http_client( kwargs["headers"] = headers # Handle authentication - if auth is not None: + if auth is not None: # pragma: no cover kwargs["auth"] = auth return httpx.AsyncClient(**kwargs) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index b7f048bbab..d3290997e5 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -21,7 +21,7 @@ def normalize_token_type(cls, v: str | None) -> str | None: # Bearer is title-cased in the spec, so we normalize it # https://datatracker.ietf.org/doc/html/rfc6750#section-4 return v.title() - return v + return v # pragma: no cover class InvalidScopeError(Exception): @@ -41,13 +41,15 @@ class OAuthClientMetadata(BaseModel): for the full specification. """ - redirect_uris: list[AnyUrl] = Field(..., min_length=1) - # token_endpoint_auth_method: this implementation only supports none & - # client_secret_post; - # ie: we do not support client_secret_basic - token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" - # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: list[Literal["authorization_code", "refresh_token"] | str] = [ + redirect_uris: list[AnyUrl] | None = Field(..., min_length=1) + # supported auth methods for the token endpoint + token_endpoint_auth_method: ( + Literal["none", "client_secret_post", "client_secret_basic", "private_key_jwt"] | None + ) = None + # supported grant_types of this implementation + grant_types: list[ + Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str + ] = [ "authorization_code", "refresh_token", ] @@ -75,17 +77,17 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: requested_scopes = requested_scope.split(" ") allowed_scopes = [] if self.scope is None else self.scope.split(" ") for scope in requested_scopes: - if scope not in allowed_scopes: + if scope not in allowed_scopes: # pragma: no branch raise InvalidScopeError(f"Client was not registered with scope {scope}") - return requested_scopes + return requested_scopes # pragma: no cover def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: # Validate redirect_uri against client's registered redirect URIs - if redirect_uri not in self.redirect_uris: + if self.redirect_uris is None or redirect_uri not in self.redirect_uris: raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") return redirect_uri - elif len(self.redirect_uris) == 1: + elif self.redirect_uris is not None and len(self.redirect_uris) == 1: return self.redirect_uris[0] else: raise InvalidRedirectUriError("redirect_uri must be specified when client has multiple registered URIs") @@ -97,7 +99,7 @@ class OAuthClientInformationFull(OAuthClientMetadata): (client information plus metadata). """ - client_id: str + client_id: str | None = None client_secret: str | None = None client_id_issued_at: int | None = None client_secret_expires_at: int | None = None @@ -130,6 +132,7 @@ class OAuthMetadata(BaseModel): introspection_endpoint_auth_methods_supported: list[str] | None = None introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None code_challenge_methods_supported: list[str] | None = None + client_id_metadata_document_supported: bool | None = None class ProtectedResourceMetadata(BaseModel): diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 6d6300c9c8..8f3c542f22 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -1,5 +1,6 @@ -"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707).""" +"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636).""" +import time from urllib.parse import urlparse, urlsplit, urlunsplit from pydantic import AnyUrl, HttpUrl @@ -67,3 +68,18 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> configured_path += "/" return requested_path.startswith(configured_path) + + +def calculate_token_expiry(expires_in: int | str | None) -> float | None: + """Calculate token expiry timestamp from expires_in seconds. + + Args: + expires_in: Seconds until token expiration (may be string from some servers) + + Returns: + Unix timestamp when token expires, or None if no expiry specified + """ + if expires_in is None: + return None # pragma: no cover + # Defensive: handle servers that return expires_in as string + return time.time() + int(expires_in) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5f..5cf6588c9e 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,8 +1,13 @@ -from dataclasses import dataclass +""" +Request context for MCP handlers. +""" + +from dataclasses import dataclass, field from typing import Any, Generic from typing_extensions import TypeVar +from mcp.shared.message import CloseSSEStreamCallback from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams @@ -17,4 +22,11 @@ class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + # NOTE: This is typed as Any to avoid circular imports. The actual type is + # mcp.server.experimental.request_context.Experimental, but importing it here + # triggers mcp.server.__init__ -> fastmcp -> tools -> back to this module. + # The Server sets this to an Experimental instance at runtime. + experimental: Any = field(default=None) request: RequestT | None = None + close_sse_stream: CloseSSEStreamCallback | None = None + close_standalone_sse_stream: CloseSSEStreamCallback | None = None diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 97a1c09a9f..4943114912 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -1,4 +1,8 @@ -from mcp.types import ErrorData +from __future__ import annotations + +from typing import Any, cast + +from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData class McpError(Exception): @@ -12,3 +16,56 @@ def __init__(self, error: ErrorData): """Initialize McpError.""" super().__init__(error.message) self.error = error + + +class UrlElicitationRequiredError(McpError): + """ + Specialized error for when a tool requires URL mode elicitation(s) before proceeding. + + Servers can raise this error from tool handlers to indicate that the client + must complete one or more URL elicitations before the request can be processed. + + Example: + raise UrlElicitationRequiredError([ + ElicitRequestURLParams( + mode="url", + message="Authorization required for your files", + url="https://example.com/oauth/authorize", + elicitationId="auth-001" + ) + ]) + """ + + def __init__( + self, + elicitations: list[ElicitRequestURLParams], + message: str | None = None, + ): + """Initialize UrlElicitationRequiredError.""" + if message is None: + message = f"URL elicitation{'s' if len(elicitations) > 1 else ''} required" + + self._elicitations = elicitations + + error = ErrorData( + code=URL_ELICITATION_REQUIRED, + message=message, + data={"elicitations": [e.model_dump(by_alias=True, exclude_none=True) for e in elicitations]}, + ) + super().__init__(error) + + @property + def elicitations(self) -> list[ElicitRequestURLParams]: + """The list of URL elicitations required before the request can proceed.""" + return self._elicitations + + @classmethod + def from_error(cls, error: ErrorData) -> UrlElicitationRequiredError: + """Reconstruct from an ErrorData received over the wire.""" + if error.code != URL_ELICITATION_REQUIRED: + raise ValueError(f"Expected error code {URL_ELICITATION_REQUIRED}, got {error.code}") + + data = cast(dict[str, Any], error.data or {}) + raw_elicitations = cast(list[dict[str, Any]], data.get("elicitations", [])) + elicitations = [ElicitRequestURLParams.model_validate(e) for e in raw_elicitations] + return cls(elicitations, error.message) diff --git a/src/mcp/shared/experimental/__init__.py b/src/mcp/shared/experimental/__init__.py new file mode 100644 index 0000000000..9b1b1479cb --- /dev/null +++ b/src/mcp/shared/experimental/__init__.py @@ -0,0 +1,7 @@ +""" +Pure experimental MCP features (no server dependencies). + +WARNING: These APIs are experimental and may change without notice. + +For server-integrated experimental features, use mcp.server.experimental. +""" diff --git a/src/mcp/shared/experimental/tasks/__init__.py b/src/mcp/shared/experimental/tasks/__init__.py new file mode 100644 index 0000000000..37d81af50b --- /dev/null +++ b/src/mcp/shared/experimental/tasks/__init__.py @@ -0,0 +1,12 @@ +""" +Pure task state management for MCP. + +WARNING: These APIs are experimental and may change without notice. + +Import directly from submodules: +- mcp.shared.experimental.tasks.store.TaskStore +- mcp.shared.experimental.tasks.context.TaskContext +- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore +- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue +- mcp.shared.experimental.tasks.helpers.is_terminal +""" diff --git a/src/mcp/shared/experimental/tasks/capabilities.py b/src/mcp/shared/experimental/tasks/capabilities.py new file mode 100644 index 0000000000..307fcdd6e5 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/capabilities.py @@ -0,0 +1,115 @@ +""" +Tasks capability checking utilities. + +This module provides functions for checking and requiring task-related +capabilities. All tasks capability logic is centralized here to keep +the main session code clean. + +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.shared.exceptions import McpError +from mcp.types import ( + INVALID_REQUEST, + ClientCapabilities, + ClientTasksCapability, + ErrorData, +) + + +def check_tasks_capability( + required: ClientTasksCapability, + client: ClientTasksCapability, +) -> bool: + """ + Check if client's tasks capability matches the required capability. + + Args: + required: The capability being checked for + client: The client's declared capabilities + + Returns: + True if client has the required capability, False otherwise + """ + if required.requests is None: + return True + if client.requests is None: + return False + + # Check elicitation.create + if required.requests.elicitation is not None: + if client.requests.elicitation is None: + return False + if required.requests.elicitation.create is not None: + if client.requests.elicitation.create is None: + return False + + # Check sampling.createMessage + if required.requests.sampling is not None: + if client.requests.sampling is None: + return False + if required.requests.sampling.createMessage is not None: + if client.requests.sampling.createMessage is None: + return False + + return True + + +def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool: + """Check if capabilities include task-augmented elicitation support.""" + if caps.tasks is None: + return False + if caps.tasks.requests is None: + return False + if caps.tasks.requests.elicitation is None: + return False + return caps.tasks.requests.elicitation.create is not None + + +def has_task_augmented_sampling(caps: ClientCapabilities) -> bool: + """Check if capabilities include task-augmented sampling support.""" + if caps.tasks is None: + return False + if caps.tasks.requests is None: + return False + if caps.tasks.requests.sampling is None: + return False + return caps.tasks.requests.sampling.createMessage is not None + + +def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None: + """ + Raise McpError if client doesn't support task-augmented elicitation. + + Args: + client_caps: The client's declared capabilities, or None if not initialized + + Raises: + McpError: If client doesn't support task-augmented elicitation + """ + if client_caps is None or not has_task_augmented_elicitation(client_caps): + raise McpError( + ErrorData( + code=INVALID_REQUEST, + message="Client does not support task-augmented elicitation", + ) + ) + + +def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None: + """ + Raise McpError if client doesn't support task-augmented sampling. + + Args: + client_caps: The client's declared capabilities, or None if not initialized + + Raises: + McpError: If client doesn't support task-augmented sampling + """ + if client_caps is None or not has_task_augmented_sampling(client_caps): + raise McpError( + ErrorData( + code=INVALID_REQUEST, + message="Client does not support task-augmented sampling", + ) + ) diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py new file mode 100644 index 0000000000..12d159515c --- /dev/null +++ b/src/mcp/shared/experimental/tasks/context.py @@ -0,0 +1,101 @@ +""" +TaskContext - Pure task state management. + +This module provides TaskContext, which manages task state without any +server/session dependencies. It can be used standalone for distributed +workers or wrapped by ServerTaskContext for full server integration. +""" + +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task + + +class TaskContext: + """ + Pure task state management - no session dependencies. + + This class handles: + - Task state (status, result) + - Cancellation tracking + - Store interactions + + For server-integrated features (elicit, create_message, notifications), + use ServerTaskContext from mcp.server.experimental. + + Example (distributed worker): + async def worker_job(task_id: str): + store = RedisTaskStore(redis_url) + task = await store.get_task(task_id) + ctx = TaskContext(task=task, store=store) + + await ctx.update_status("Working...") + result = await do_work() + await ctx.complete(result) + """ + + def __init__(self, task: Task, store: TaskStore): + self._task = task + self._store = store + self._cancelled = False + + @property + def task_id(self) -> str: + """The task identifier.""" + return self._task.taskId + + @property + def task(self) -> Task: + """The current task state.""" + return self._task + + @property + def is_cancelled(self) -> bool: + """Whether cancellation has been requested.""" + return self._cancelled + + def request_cancellation(self) -> None: + """ + Request cancellation of this task. + + This sets is_cancelled=True. Task work should check this + periodically and exit gracefully if set. + """ + self._cancelled = True + + async def update_status(self, message: str) -> None: + """ + Update the task's status message. + + Args: + message: The new status message + """ + self._task = await self._store.update_task( + self.task_id, + status_message=message, + ) + + async def complete(self, result: Result) -> None: + """ + Mark the task as completed with the given result. + + Args: + result: The task result + """ + await self._store.store_result(self.task_id, result) + self._task = await self._store.update_task( + self.task_id, + status=TASK_STATUS_COMPLETED, + ) + + async def fail(self, error: str) -> None: + """ + Mark the task as failed with an error message. + + Args: + error: The error message + """ + self._task = await self._store.update_task( + self.task_id, + status=TASK_STATUS_FAILED, + status_message=error, + ) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py new file mode 100644 index 0000000000..5c87f9ef87 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -0,0 +1,181 @@ +""" +Helper functions for pure task management. + +These helpers work with pure TaskContext and don't require server dependencies. +For server-integrated task helpers, use mcp.server.experimental. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from uuid import uuid4 + +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.context import TaskContext +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import ( + INVALID_PARAMS, + TASK_STATUS_CANCELLED, + TASK_STATUS_COMPLETED, + TASK_STATUS_FAILED, + TASK_STATUS_WORKING, + CancelTaskResult, + ErrorData, + Task, + TaskMetadata, + TaskStatus, +) + +# Metadata key for model-immediate-response (per MCP spec) +# Servers MAY include this in CreateTaskResult._meta to provide an immediate +# response string while the task executes in the background. +MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response" + +# Metadata key for associating requests with a task (per MCP spec) +RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task" + + +def is_terminal(status: TaskStatus) -> bool: + """ + Check if a task status represents a terminal state. + + Terminal states are those where the task has finished and will not change. + + Args: + status: The task status to check + + Returns: + True if the status is terminal (completed, failed, or cancelled) + """ + return status in (TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED) + + +async def cancel_task( + store: TaskStore, + task_id: str, +) -> CancelTaskResult: + """ + Cancel a task with spec-compliant validation. + + Per spec: "Receivers MUST reject cancellation of terminal status tasks + with -32602 (Invalid params)" + + This helper validates that the task exists and is not in a terminal state + before setting it to "cancelled". + + Args: + store: The task store + task_id: The task identifier to cancel + + Returns: + CancelTaskResult with the cancelled task state + + Raises: + McpError: With INVALID_PARAMS (-32602) if: + - Task does not exist + - Task is already in a terminal state (completed, failed, cancelled) + + Example: + @server.experimental.cancel_task() + async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult: + return await cancel_task(store, request.params.taskId) + """ + task = await store.get_task(task_id) + if task is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {task_id}", + ) + ) + + if is_terminal(task.status): + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Cannot cancel task in terminal state '{task.status}'", + ) + ) + + # Update task to cancelled status + cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED) + return CancelTaskResult(**cancelled_task.model_dump()) + + +def generate_task_id() -> str: + """Generate a unique task ID.""" + return str(uuid4()) + + +def create_task_state( + metadata: TaskMetadata, + task_id: str | None = None, +) -> Task: + """ + Create a Task object with initial state. + + This is a helper for TaskStore implementations. + + Args: + metadata: Task metadata + task_id: Optional task ID (generated if not provided) + + Returns: + A new Task in "working" status + """ + now = datetime.now(timezone.utc) + return Task( + taskId=task_id or generate_task_id(), + status=TASK_STATUS_WORKING, + createdAt=now, + lastUpdatedAt=now, + ttl=metadata.ttl, + pollInterval=500, # Default 500ms poll interval + ) + + +@asynccontextmanager +async def task_execution( + task_id: str, + store: TaskStore, +) -> AsyncIterator[TaskContext]: + """ + Context manager for safe task execution (pure, no server dependencies). + + Loads a task from the store and provides a TaskContext for the work. + If an unhandled exception occurs, the task is automatically marked as failed + and the exception is suppressed (since the failure is captured in task state). + + This is useful for distributed workers that don't have a server session. + + Args: + task_id: The task identifier to execute + store: The task store (must be accessible by the worker) + + Yields: + TaskContext for updating status and completing/failing the task + + Raises: + ValueError: If the task is not found in the store + + Example (distributed worker): + async def worker_process(task_id: str): + store = RedisTaskStore(redis_url) + async with task_execution(task_id, store) as ctx: + await ctx.update_status("Working...") + result = await do_work() + await ctx.complete(result) + """ + task = await store.get_task(task_id) + if task is None: + raise ValueError(f"Task {task_id} not found") + + ctx = TaskContext(task, store) + try: + yield ctx + except Exception as e: + # Auto-fail the task if an exception occurs and task isn't already terminal + # Exception is suppressed since failure is captured in task state + if not is_terminal(ctx.task.status): + await ctx.fail(str(e)) + # Don't re-raise - the failure is recorded in task state diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py new file mode 100644 index 0000000000..7b630ce6e2 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py @@ -0,0 +1,219 @@ +""" +In-memory implementation of TaskStore for demonstration purposes. + +This implementation stores all tasks in memory and provides automatic cleanup +based on the TTL duration specified in the task metadata using lazy expiration. + +Note: This is not suitable for production use as all data is lost on restart. +For production, consider implementing TaskStore with a database or distributed cache. +""" + +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone + +import anyio + +from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import Result, Task, TaskMetadata, TaskStatus + + +@dataclass +class StoredTask: + """Internal storage representation of a task.""" + + task: Task + result: Result | None = None + # Time when this task should be removed (None = never) + expires_at: datetime | None = field(default=None) + + +class InMemoryTaskStore(TaskStore): + """ + A simple in-memory implementation of TaskStore. + + Features: + - Automatic TTL-based cleanup (lazy expiration) + - Thread-safe for single-process async use + - Pagination support for list_tasks + + Limitations: + - All data lost on restart + - Not suitable for distributed systems + - No persistence + + For production, implement TaskStore with Redis, PostgreSQL, etc. + """ + + def __init__(self, page_size: int = 10) -> None: + self._tasks: dict[str, StoredTask] = {} + self._page_size = page_size + self._update_events: dict[str, anyio.Event] = {} + + def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None: + """Calculate expiry time from TTL in milliseconds.""" + if ttl_ms is None: + return None + return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms) + + def _is_expired(self, stored: StoredTask) -> bool: + """Check if a task has expired.""" + if stored.expires_at is None: + return False + return datetime.now(timezone.utc) >= stored.expires_at + + def _cleanup_expired(self) -> None: + """Remove all expired tasks. Called lazily during access operations.""" + expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)] + for task_id in expired_ids: + del self._tasks[task_id] + + async def create_task( + self, + metadata: TaskMetadata, + task_id: str | None = None, + ) -> Task: + """Create a new task with the given metadata.""" + # Cleanup expired tasks on access + self._cleanup_expired() + + task = create_task_state(metadata, task_id) + + if task.taskId in self._tasks: + raise ValueError(f"Task with ID {task.taskId} already exists") + + stored = StoredTask( + task=task, + expires_at=self._calculate_expiry(metadata.ttl), + ) + self._tasks[task.taskId] = stored + + # Return a copy to prevent external modification + return Task(**task.model_dump()) + + async def get_task(self, task_id: str) -> Task | None: + """Get a task by ID.""" + # Cleanup expired tasks on access + self._cleanup_expired() + + stored = self._tasks.get(task_id) + if stored is None: + return None + + # Return a copy to prevent external modification + return Task(**stored.task.model_dump()) + + async def update_task( + self, + task_id: str, + status: TaskStatus | None = None, + status_message: str | None = None, + ) -> Task: + """Update a task's status and/or message.""" + stored = self._tasks.get(task_id) + if stored is None: + raise ValueError(f"Task with ID {task_id} not found") + + # Per spec: Terminal states MUST NOT transition to any other status + if status is not None and status != stored.task.status and is_terminal(stored.task.status): + raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'") + + status_changed = False + if status is not None and stored.task.status != status: + stored.task.status = status + status_changed = True + + if status_message is not None: + stored.task.statusMessage = status_message + + # Update lastUpdatedAt on any change + stored.task.lastUpdatedAt = datetime.now(timezone.utc) + + # If task is now terminal and has TTL, reset expiry timer + if status is not None and is_terminal(status) and stored.task.ttl is not None: + stored.expires_at = self._calculate_expiry(stored.task.ttl) + + # Notify waiters if status changed + if status_changed: + await self.notify_update(task_id) + + return Task(**stored.task.model_dump()) + + async def store_result(self, task_id: str, result: Result) -> None: + """Store the result for a task.""" + stored = self._tasks.get(task_id) + if stored is None: + raise ValueError(f"Task with ID {task_id} not found") + + stored.result = result + + async def get_result(self, task_id: str) -> Result | None: + """Get the stored result for a task.""" + stored = self._tasks.get(task_id) + if stored is None: + return None + + return stored.result + + async def list_tasks( + self, + cursor: str | None = None, + ) -> tuple[list[Task], str | None]: + """List tasks with pagination.""" + # Cleanup expired tasks on access + self._cleanup_expired() + + all_task_ids = list(self._tasks.keys()) + + start_index = 0 + if cursor is not None: + try: + cursor_index = all_task_ids.index(cursor) + start_index = cursor_index + 1 + except ValueError: + raise ValueError(f"Invalid cursor: {cursor}") + + page_task_ids = all_task_ids[start_index : start_index + self._page_size] + tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] + + # Determine next cursor + next_cursor = None + if start_index + self._page_size < len(all_task_ids) and page_task_ids: + next_cursor = page_task_ids[-1] + + return tasks, next_cursor + + async def delete_task(self, task_id: str) -> bool: + """Delete a task.""" + if task_id not in self._tasks: + return False + + del self._tasks[task_id] + return True + + async def wait_for_update(self, task_id: str) -> None: + """Wait until the task status changes.""" + if task_id not in self._tasks: + raise ValueError(f"Task with ID {task_id} not found") + + # Create a fresh event for waiting (anyio.Event can't be cleared) + self._update_events[task_id] = anyio.Event() + event = self._update_events[task_id] + await event.wait() + + async def notify_update(self, task_id: str) -> None: + """Signal that a task has been updated.""" + if task_id in self._update_events: + self._update_events[task_id].set() + + # --- Testing/debugging helpers --- + + def cleanup(self) -> None: + """Cleanup all tasks (useful for testing or graceful shutdown).""" + self._tasks.clear() + self._update_events.clear() + + def get_all_tasks(self) -> list[Task]: + """Get all tasks (useful for debugging). Returns copies to prevent modification.""" + self._cleanup_expired() + return [Task(**stored.task.model_dump()) for stored in self._tasks.values()] diff --git a/src/mcp/shared/experimental/tasks/message_queue.py b/src/mcp/shared/experimental/tasks/message_queue.py new file mode 100644 index 0000000000..69b6609887 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/message_queue.py @@ -0,0 +1,241 @@ +""" +TaskMessageQueue - FIFO queue for task-related messages. + +This implements the core message queue pattern from the MCP Tasks spec. +When a handler needs to send a request (like elicitation) during a task-augmented +request, the message is enqueued instead of sent directly. Messages are delivered +to the client only through the `tasks/result` endpoint. + +This pattern enables: +1. Decoupling request handling from message delivery +2. Proper bidirectional communication via the tasks/result stream +3. Automatic status management (working <-> input_required) +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +import anyio + +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId + + +@dataclass +class QueuedMessage: + """ + A message queued for delivery via tasks/result. + + Messages are stored with their type and a resolver for requests + that expect responses. + """ + + type: Literal["request", "notification"] + """Whether this is a request (expects response) or notification (one-way).""" + + message: JSONRPCRequest | JSONRPCNotification + """The JSON-RPC message to send.""" + + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + """When the message was enqueued.""" + + resolver: Resolver[dict[str, Any]] | None = None + """Resolver to set when response arrives (only for requests).""" + + original_request_id: RequestId | None = None + """The original request ID used internally, for routing responses back.""" + + +class TaskMessageQueue(ABC): + """ + Abstract interface for task message queuing. + + This is a FIFO queue that stores messages to be delivered via `tasks/result`. + When a task-augmented handler calls elicit() or sends a notification, the + message is enqueued here instead of being sent directly to the client. + + The `tasks/result` handler then dequeues and sends these messages through + the transport, with `relatedRequestId` set to the tasks/result request ID + so responses are routed correctly. + + Implementations can use in-memory storage, Redis, etc. + """ + + @abstractmethod + async def enqueue(self, task_id: str, message: QueuedMessage) -> None: + """ + Add a message to the queue for a task. + + Args: + task_id: The task identifier + message: The message to enqueue + """ + + @abstractmethod + async def dequeue(self, task_id: str) -> QueuedMessage | None: + """ + Remove and return the next message from the queue. + + Args: + task_id: The task identifier + + Returns: + The next message, or None if queue is empty + """ + + @abstractmethod + async def peek(self, task_id: str) -> QueuedMessage | None: + """ + Return the next message without removing it. + + Args: + task_id: The task identifier + + Returns: + The next message, or None if queue is empty + """ + + @abstractmethod + async def is_empty(self, task_id: str) -> bool: + """ + Check if the queue is empty for a task. + + Args: + task_id: The task identifier + + Returns: + True if no messages are queued + """ + + @abstractmethod + async def clear(self, task_id: str) -> list[QueuedMessage]: + """ + Remove and return all messages from the queue. + + This is useful for cleanup when a task is cancelled or completed. + + Args: + task_id: The task identifier + + Returns: + All queued messages (may be empty) + """ + + @abstractmethod + async def wait_for_message(self, task_id: str) -> None: + """ + Wait until a message is available in the queue. + + This blocks until either: + 1. A message is enqueued for this task + 2. The wait is cancelled + + Args: + task_id: The task identifier + """ + + @abstractmethod + async def notify_message_available(self, task_id: str) -> None: + """ + Signal that a message is available for a task. + + This wakes up any coroutines waiting in wait_for_message(). + + Args: + task_id: The task identifier + """ + + +class InMemoryTaskMessageQueue(TaskMessageQueue): + """ + In-memory implementation of TaskMessageQueue. + + This is suitable for single-process servers. For distributed systems, + implement TaskMessageQueue with Redis, RabbitMQ, etc. + + Features: + - FIFO ordering per task + - Async wait for message availability + - Thread-safe for single-process async use + """ + + def __init__(self) -> None: + self._queues: dict[str, list[QueuedMessage]] = {} + self._events: dict[str, anyio.Event] = {} + + def _get_queue(self, task_id: str) -> list[QueuedMessage]: + """Get or create the queue for a task.""" + if task_id not in self._queues: + self._queues[task_id] = [] + return self._queues[task_id] + + async def enqueue(self, task_id: str, message: QueuedMessage) -> None: + """Add a message to the queue.""" + queue = self._get_queue(task_id) + queue.append(message) + # Signal that a message is available + await self.notify_message_available(task_id) + + async def dequeue(self, task_id: str) -> QueuedMessage | None: + """Remove and return the next message.""" + queue = self._get_queue(task_id) + if not queue: + return None + return queue.pop(0) + + async def peek(self, task_id: str) -> QueuedMessage | None: + """Return the next message without removing it.""" + queue = self._get_queue(task_id) + if not queue: + return None + return queue[0] + + async def is_empty(self, task_id: str) -> bool: + """Check if the queue is empty.""" + queue = self._get_queue(task_id) + return len(queue) == 0 + + async def clear(self, task_id: str) -> list[QueuedMessage]: + """Remove and return all messages.""" + queue = self._get_queue(task_id) + messages = list(queue) + queue.clear() + return messages + + async def wait_for_message(self, task_id: str) -> None: + """Wait until a message is available.""" + # Check if there are already messages + if not await self.is_empty(task_id): + return + + # Create a fresh event for waiting (anyio.Event can't be cleared) + self._events[task_id] = anyio.Event() + event = self._events[task_id] + + # Double-check after creating event (avoid race condition) + if not await self.is_empty(task_id): + return + + # Wait for a new message + await event.wait() + + async def notify_message_available(self, task_id: str) -> None: + """Signal that a message is available.""" + if task_id in self._events: + self._events[task_id].set() + + def cleanup(self, task_id: str | None = None) -> None: + """ + Clean up queues and events. + + Args: + task_id: If provided, clean up only this task. Otherwise clean up all. + """ + if task_id is not None: + self._queues.pop(task_id, None) + self._events.pop(task_id, None) + else: + self._queues.clear() + self._events.clear() diff --git a/src/mcp/shared/experimental/tasks/polling.py b/src/mcp/shared/experimental/tasks/polling.py new file mode 100644 index 0000000000..39db2e6b68 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/polling.py @@ -0,0 +1,45 @@ +""" +Shared polling utilities for task operations. + +This module provides generic polling logic that works for both client→server +and server→client task polling. + +WARNING: These APIs are experimental and may change without notice. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable + +import anyio + +from mcp.shared.experimental.tasks.helpers import is_terminal +from mcp.types import GetTaskResult + + +async def poll_until_terminal( + get_task: Callable[[str], Awaitable[GetTaskResult]], + task_id: str, + default_interval_ms: int = 500, +) -> AsyncIterator[GetTaskResult]: + """ + Poll a task until it reaches terminal status. + + This is a generic utility that works for both client→server and server→client + polling. The caller provides the get_task function appropriate for their direction. + + Args: + get_task: Async function that takes task_id and returns GetTaskResult + task_id: The task to poll + default_interval_ms: Fallback poll interval if server doesn't specify + + Yields: + GetTaskResult for each poll + """ + while True: + status = await get_task(task_id) + yield status + + if is_terminal(status.status): + break + + interval_ms = status.pollInterval if status.pollInterval is not None else default_interval_ms + await anyio.sleep(interval_ms / 1000) diff --git a/src/mcp/shared/experimental/tasks/resolver.py b/src/mcp/shared/experimental/tasks/resolver.py new file mode 100644 index 0000000000..f27425b2c6 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/resolver.py @@ -0,0 +1,60 @@ +""" +Resolver - An anyio-compatible future-like object for async result passing. + +This provides a simple way to pass a result (or exception) from one coroutine +to another without depending on asyncio.Future. +""" + +from typing import Generic, TypeVar, cast + +import anyio + +T = TypeVar("T") + + +class Resolver(Generic[T]): + """ + A simple resolver for passing results between coroutines. + + Unlike asyncio.Future, this works with any anyio-compatible async backend. + + Usage: + resolver: Resolver[str] = Resolver() + + # In one coroutine: + resolver.set_result("hello") + + # In another coroutine: + result = await resolver.wait() # returns "hello" + """ + + def __init__(self) -> None: + self._event = anyio.Event() + self._value: T | None = None + self._exception: BaseException | None = None + + def set_result(self, value: T) -> None: + """Set the result value and wake up waiters.""" + if self._event.is_set(): + raise RuntimeError("Resolver already completed") + self._value = value + self._event.set() + + def set_exception(self, exc: BaseException) -> None: + """Set an exception and wake up waiters.""" + if self._event.is_set(): + raise RuntimeError("Resolver already completed") + self._exception = exc + self._event.set() + + async def wait(self) -> T: + """Wait for the result and return it, or raise the exception.""" + await self._event.wait() + if self._exception is not None: + raise self._exception + # If we reach here, set_result() was called, so _value is set + return cast(T, self._value) + + def done(self) -> bool: + """Return True if the resolver has been completed.""" + return self._event.is_set() diff --git a/src/mcp/shared/experimental/tasks/store.py b/src/mcp/shared/experimental/tasks/store.py new file mode 100644 index 0000000000..71fb4511b8 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/store.py @@ -0,0 +1,156 @@ +""" +TaskStore - Abstract interface for task state storage. +""" + +from abc import ABC, abstractmethod + +from mcp.types import Result, Task, TaskMetadata, TaskStatus + + +class TaskStore(ABC): + """ + Abstract interface for task state storage. + + This is a pure storage interface - it doesn't manage execution. + Implementations can use in-memory storage, databases, Redis, etc. + + All methods are async to support various backends. + """ + + @abstractmethod + async def create_task( + self, + metadata: TaskMetadata, + task_id: str | None = None, + ) -> Task: + """ + Create a new task. + + Args: + metadata: Task metadata (ttl, etc.) + task_id: Optional task ID. If None, implementation should generate one. + + Returns: + The created Task with status="working" + + Raises: + ValueError: If task_id already exists + """ + + @abstractmethod + async def get_task(self, task_id: str) -> Task | None: + """ + Get a task by ID. + + Args: + task_id: The task identifier + + Returns: + The Task, or None if not found + """ + + @abstractmethod + async def update_task( + self, + task_id: str, + status: TaskStatus | None = None, + status_message: str | None = None, + ) -> Task: + """ + Update a task's status and/or message. + + Args: + task_id: The task identifier + status: New status (if changing) + status_message: New status message (if changing) + + Returns: + The updated Task + + Raises: + ValueError: If task not found + ValueError: If attempting to transition from a terminal status + (completed, failed, cancelled). Per spec, terminal states + MUST NOT transition to any other status. + """ + + @abstractmethod + async def store_result(self, task_id: str, result: Result) -> None: + """ + Store the result for a task. + + Args: + task_id: The task identifier + result: The result to store + + Raises: + ValueError: If task not found + """ + + @abstractmethod + async def get_result(self, task_id: str) -> Result | None: + """ + Get the stored result for a task. + + Args: + task_id: The task identifier + + Returns: + The stored Result, or None if not available + """ + + @abstractmethod + async def list_tasks( + self, + cursor: str | None = None, + ) -> tuple[list[Task], str | None]: + """ + List tasks with pagination. + + Args: + cursor: Optional cursor for pagination + + Returns: + Tuple of (tasks, next_cursor). next_cursor is None if no more pages. + """ + + @abstractmethod + async def delete_task(self, task_id: str) -> bool: + """ + Delete a task. + + Args: + task_id: The task identifier + + Returns: + True if deleted, False if not found + """ + + @abstractmethod + async def wait_for_update(self, task_id: str) -> None: + """ + Wait until the task status changes. + + This blocks until either: + 1. The task status changes + 2. The wait is cancelled + + Used by tasks/result to wait for task completion or status changes. + + Args: + task_id: The task identifier + + Raises: + ValueError: If task not found + """ + + @abstractmethod + async def notify_update(self, task_id: str) -> None: + """ + Signal that a task has been updated. + + This wakes up any coroutines waiting in wait_for_update(). + + Args: + task_id: The task identifier + """ diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 265d07c378..c7c6dbabc2 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -6,7 +6,6 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from datetime import timedelta from typing import Any import anyio @@ -49,7 +48,7 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS @asynccontextmanager async def create_connected_server_and_client_session( server: Server[Any] | FastMCP, - read_timeout_seconds: timedelta | None = None, + read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, @@ -62,7 +61,7 @@ async def create_connected_server_and_client_session( # TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport", # and we should expose a method in the `FastMCP` so we don't access a private attribute. - if isinstance(server, FastMCP): + if isinstance(server, FastMCP): # pragma: no cover server = server._mcp_server # type: ignore[reportPrivateUsage] async with create_client_server_memory_streams() as (client_streams, server_streams): @@ -94,5 +93,5 @@ async def create_connected_server_and_client_session( ) as client_session: await client_session.initialize() yield client_session - finally: + finally: # pragma: no cover tg.cancel_scope.cancel() diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 4b6df23eb6..81503eaaa7 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -14,6 +14,9 @@ ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] +# Callback type for closing SSE streams without terminating +CloseSSEStreamCallback = Callable[[], Awaitable[None]] + @dataclass class ClientMessageMetadata: @@ -30,6 +33,10 @@ class ServerMessageMetadata: related_request_id: RequestId | None = None # Request-specific context (e.g., headers, auth info) request_context: object | None = None + # Callback to close SSE stream for the current request without terminating + close_sse_stream: CloseSSEStreamCallback | None = None + # Callback to close the standalone GET SSE stream (for unsolicited notifications) + close_standalone_sse_stream: CloseSSEStreamCallback | None = None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 1ad81a779c..a230c58b45 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -48,7 +48,7 @@ def progress( ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], None, ]: - if ctx.meta is None or ctx.meta.progressToken is None: + if ctx.meta is None or ctx.meta.progressToken is None: # pragma: no cover raise ValueError("No progress token provided") progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total) diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py new file mode 100644 index 0000000000..31796157fe --- /dev/null +++ b/src/mcp/shared/response_router.py @@ -0,0 +1,63 @@ +""" +ResponseRouter - Protocol for pluggable response routing. + +This module defines a protocol for routing JSON-RPC responses to alternative +handlers before falling back to the default response stream mechanism. + +The primary use case is task-augmented requests: when a TaskSession enqueues +a request (like elicitation), the response needs to be routed back to the +waiting resolver instead of the normal response stream. + +Design: +- Protocol-based for testability and flexibility +- Returns bool to indicate if response was handled +- Supports both success responses and errors +""" + +from typing import Any, Protocol + +from mcp.types import ErrorData, RequestId + + +class ResponseRouter(Protocol): + """ + Protocol for routing responses to alternative handlers. + + Implementations check if they have a pending request for the given ID + and deliver the response/error to the appropriate handler. + + Example: + class TaskResultHandler(ResponseRouter): + def route_response(self, request_id, response): + resolver = self._pending_requests.pop(request_id, None) + if resolver: + resolver.set_result(response) + return True + return False + """ + + def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool: + """ + Try to route a response to a pending request handler. + + Args: + request_id: The JSON-RPC request ID from the response + response: The response result data + + Returns: + True if the response was handled, False otherwise + """ + ... # pragma: no cover + + def route_error(self, request_id: RequestId, error: ErrorData) -> bool: + """ + Try to route an error to a pending request handler. + + Args: + request_id: The JSON-RPC request ID from the error response + error: The error data + + Returns: + True if the error was handled, False otherwise + """ + ... # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 4e774984d4..c807e291c4 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,7 +1,6 @@ import logging from collections.abc import Callable from contextlib import AsyncExitStack -from datetime import timedelta from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -13,6 +12,7 @@ from mcp.shared.exceptions import McpError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.response_router import ResponseRouter from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, @@ -46,7 +46,9 @@ class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" - async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + async def __call__( + self, progress: float, total: float | None, message: str | None + ) -> None: ... # pragma: no branch class RequestResponder(Generic[ReceiveRequestT, SendResultT]): @@ -105,11 +107,11 @@ def __exit__( ) -> None: """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: + if self._completed: # pragma: no branch self._on_complete(self) finally: self._entered = False - if not self._cancel_scope: + if not self._cancel_scope: # pragma: no cover raise RuntimeError("No active cancel scope") self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) @@ -121,11 +123,11 @@ async def respond(self, response: SendResultT | ErrorData) -> None: RuntimeError: If not used within a context manager AssertionError: If request was already responded to """ - if not self._entered: + if not self._entered: # pragma: no cover raise RuntimeError("RequestResponder must be used as a context manager") assert not self._completed, "Request already responded to" - if not self.cancelled: + if not self.cancelled: # pragma: no branch self._completed = True await self._session._send_response( # type: ignore[reportPrivateUsage] @@ -134,9 +136,9 @@ async def respond(self, response: SendResultT | ErrorData) -> None: async def cancel(self) -> None: """Cancel this request and mark it as completed.""" - if not self._entered: + if not self._entered: # pragma: no cover raise RuntimeError("RequestResponder must be used as a context manager") - if not self._cancel_scope: + if not self._cancel_scope: # pragma: no cover raise RuntimeError("No active cancel scope") self._cancel_scope.cancel() @@ -148,11 +150,11 @@ async def cancel(self) -> None: ) @property - def in_flight(self) -> bool: + def in_flight(self) -> bool: # pragma: no cover return not self._completed and not self.cancelled @property - def cancelled(self) -> bool: + def cancelled(self) -> bool: # pragma: no cover return self._cancel_scope.cancel_called @@ -177,6 +179,7 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _response_routers: list["ResponseRouter"] def __init__( self, @@ -185,7 +188,7 @@ def __init__( receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out - read_timeout_seconds: timedelta | None = None, + read_timeout_seconds: float | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -196,8 +199,24 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._response_routers = [] self._exit_stack = AsyncExitStack() + def add_response_router(self, router: ResponseRouter) -> None: + """ + Register a response router to handle responses for non-standard requests. + + Response routers are checked in order before falling back to the default + response stream mechanism. This is used by TaskResultHandler to route + responses for queued task requests back to their resolvers. + + WARNING: This is an experimental API that may change without notice. + + Args: + router: A ResponseRouter implementation + """ + self._response_routers.append(router) + async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() @@ -221,7 +240,7 @@ async def send_request( self, request: SendRequestT, result_type: type[ReceiveResultT], - request_read_timeout_seconds: timedelta | None = None, + request_read_timeout_seconds: float | None = None, metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: @@ -241,11 +260,11 @@ async def send_request( # Set up progress token if progress callback is provided request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: + if progress_callback is not None: # pragma: no cover # Use request_id as progress token if "params" not in request_data: request_data["params"] = {} - if "_meta" not in request_data["params"]: + if "_meta" not in request_data["params"]: # pragma: no branch request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id # Store the callback for this request @@ -262,10 +281,10 @@ async def send_request( # request read timeout takes precedence over session read timeout timeout = None - if request_read_timeout_seconds is not None: - timeout = request_read_timeout_seconds.total_seconds() - elif self._session_read_timeout_seconds is not None: - timeout = self._session_read_timeout_seconds.total_seconds() + if request_read_timeout_seconds is not None: # pragma: no cover + timeout = request_read_timeout_seconds + elif self._session_read_timeout_seconds is not None: # pragma: no cover + timeout = self._session_read_timeout_seconds try: with anyio.fail_after(timeout): @@ -308,7 +327,7 @@ async def send_notification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - session_message = SessionMessage( + session_message = SessionMessage( # pragma: no cover message=JSONRPCMessage(jsonrpc_notification), metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) @@ -335,7 +354,7 @@ async def _receive_loop(self) -> None: ): try: async for message in self._read_stream: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover await self._handle_incoming(message) elif isinstance(message.message.root, JSONRPCRequest): try: @@ -382,11 +401,11 @@ async def _receive_loop(self) -> None: # Handle cancellation notifications if isinstance(notification.root, CancelledNotification): cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: + if cancelled_id in self._in_flight: # pragma: no branch await self._in_flight[cancelled_id].cancel() else: # Handle progress notifications callback - if isinstance(notification.root, ProgressNotification): + if isinstance(notification.root, ProgressNotification): # pragma: no cover progress_token = notification.root.params.progressToken # If there is a progress callback for this token, # call it with the progress information @@ -405,26 +424,20 @@ async def _receive_loop(self) -> None: ) await self._received_notification(notification) await self._handle_incoming(notification) - except Exception as e: + except Exception as e: # pragma: no cover # For other validation errors, log and continue logging.warning( f"Failed to validate notification: {e}. Message was: {message.message.root}" ) else: # Response or error - stream = self._response_streams.pop(message.message.root.id, None) - if stream: - await stream.send(message.message.root) - else: - await self._handle_incoming( - RuntimeError(f"Received response with an unknown request ID: {message}") - ) + await self._handle_response(message) except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. # Without this handler, the exception would propagate up and # crash the server's task group. - logging.debug("Read stream closed by client") - except Exception as e: + logging.debug("Read stream closed by client") # pragma: no cover + except Exception as e: # pragma: no cover # Other exceptions are not expected and should be logged. We purposefully # catch all exceptions here to avoid crashing the server. logging.exception(f"Unhandled exception in receive loop: {e}") @@ -436,11 +449,71 @@ async def _receive_loop(self) -> None: try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: + except Exception: # pragma: no cover # Stream might already be closed pass self._response_streams.clear() + def _normalize_request_id(self, response_id: RequestId) -> RequestId: + """ + Normalize a response ID to match how request IDs are stored. + + Since the client always sends integer IDs, we normalize string IDs + to integers when possible. This matches the TypeScript SDK approach: + https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861 + + Args: + response_id: The response ID from the incoming message. + + Returns: + The normalized ID (int if possible, otherwise original value). + """ + if isinstance(response_id, str): + try: + return int(response_id) + except ValueError: + logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests") + return response_id + + async def _handle_response(self, message: SessionMessage) -> None: + """ + Handle an incoming response or error message. + + Checks response routers first (e.g., for task-related responses), + then falls back to the normal response stream mechanism. + """ + root = message.message.root + + # This check is always true at runtime: the caller (_receive_loop) only invokes + # this method in the else branch after checking for JSONRPCRequest and + # JSONRPCNotification. However, the type checker can't infer this from the + # method signature, so we need this guard for type narrowing. + if not isinstance(root, JSONRPCResponse | JSONRPCError): + return # pragma: no cover + + # Normalize response ID to handle type mismatches (e.g., "0" vs 0) + response_id = self._normalize_request_id(root.id) + + # First, check response routers (e.g., TaskResultHandler) + if isinstance(root, JSONRPCError): + # Route error to routers + for router in self._response_routers: + if router.route_error(response_id, root.error): + return # Handled + else: + # Route success response to routers + response_data: dict[str, Any] = root.result or {} + for router in self._response_routers: + if router.route_response(response_id, response_data): + return # Handled + + # Fall back to normal response streams + stream = self._response_streams.pop(response_id, None) + if stream: # pragma: no cover + await stream.send(root) + else: # pragma: no cover + await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) + async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ Can be overridden by subclasses to handle a request without needing to @@ -473,4 +546,4 @@ async def _handle_incoming( req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, ) -> None: """A generic handler for incoming messages. Overwritten by subclasses.""" - pass + pass # pragma: no cover diff --git a/src/mcp/shared/tool_name_validation.py b/src/mcp/shared/tool_name_validation.py new file mode 100644 index 0000000000..f35efa5a61 --- /dev/null +++ b/src/mcp/shared/tool_name_validation.py @@ -0,0 +1,129 @@ +"""Tool name validation utilities according to SEP-986. + +Tool names SHOULD be between 1 and 128 characters in length (inclusive). +Tool names are case-sensitive. +Allowed characters: uppercase and lowercase ASCII letters (A-Z, a-z), +digits (0-9), underscore (_), dash (-), and dot (.). +Tool names SHOULD NOT contain spaces, commas, or other special characters. + +See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names +""" + +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + +# Regular expression for valid tool names according to SEP-986 specification +TOOL_NAME_REGEX = re.compile(r"^[A-Za-z0-9._-]{1,128}$") + +# SEP reference URL for warning messages +SEP_986_URL = "https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names" + + +@dataclass +class ToolNameValidationResult: + """Result of tool name validation. + + Attributes: + is_valid: Whether the tool name conforms to SEP-986 requirements. + warnings: List of warning messages for non-conforming aspects. + """ + + is_valid: bool + warnings: list[str] = field(default_factory=lambda: []) + + +def validate_tool_name(name: str) -> ToolNameValidationResult: + """Validate a tool name according to the SEP-986 specification. + + Args: + name: The tool name to validate. + + Returns: + ToolNameValidationResult containing validation status and any warnings. + """ + warnings: list[str] = [] + + # Check for empty name + if not name: + return ToolNameValidationResult( + is_valid=False, + warnings=["Tool name cannot be empty"], + ) + + # Check length + if len(name) > 128: + return ToolNameValidationResult( + is_valid=False, + warnings=[f"Tool name exceeds maximum length of 128 characters (current: {len(name)})"], + ) + + # Check for problematic patterns (warnings, not validation failures) + if " " in name: + warnings.append("Tool name contains spaces, which may cause parsing issues") + + if "," in name: + warnings.append("Tool name contains commas, which may cause parsing issues") + + # Check for potentially confusing leading/trailing characters + if name.startswith("-") or name.endswith("-"): + warnings.append("Tool name starts or ends with a dash, which may cause parsing issues in some contexts") + + if name.startswith(".") or name.endswith("."): + warnings.append("Tool name starts or ends with a dot, which may cause parsing issues in some contexts") + + # Check for invalid characters + if not TOOL_NAME_REGEX.match(name): + # Find all invalid characters (unique, preserving order) + invalid_chars: list[str] = [] + seen: set[str] = set() + for char in name: + if not re.match(r"[A-Za-z0-9._-]", char) and char not in seen: + invalid_chars.append(char) + seen.add(char) + + warnings.append(f"Tool name contains invalid characters: {', '.join(repr(c) for c in invalid_chars)}") + warnings.append("Allowed characters are: A-Z, a-z, 0-9, underscore (_), dash (-), and dot (.)") + + return ToolNameValidationResult(is_valid=False, warnings=warnings) + + return ToolNameValidationResult(is_valid=True, warnings=warnings) + + +def issue_tool_name_warning(name: str, warnings: list[str]) -> None: + """Log warnings for non-conforming tool names. + + Args: + name: The tool name that triggered the warnings. + warnings: List of warning messages to log. + """ + if not warnings: + return + + logger.warning(f'Tool name validation warning for "{name}":') + for warning in warnings: + logger.warning(f" - {warning}") + logger.warning("Tool registration will proceed, but this may cause compatibility issues.") + logger.warning("Consider updating the tool name to conform to the MCP tool naming standard.") + logger.warning(f"See SEP-986 ({SEP_986_URL}) for more details.") + + +def validate_and_warn_tool_name(name: str) -> bool: + """Validate a tool name and issue warnings for non-conforming names. + + This is the primary entry point for tool name validation. It validates + the name and logs any warnings via the logging module. + + Args: + name: The tool name to validate. + + Returns: + True if the name is valid, False otherwise. + """ + result = validate_tool_name(name) + issue_tool_name_warning(name, result.warnings) + return result.is_valid diff --git a/src/mcp/shared/version.py b/src/mcp/shared/version.py index 23c46d04be..d2a1e462d4 100644 --- a/src/mcp/shared/version.py +++ b/src/mcp/shared/version.py @@ -1,3 +1,3 @@ from mcp.types import LATEST_PROTOCOL_VERSION -SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] +SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", "2025-06-18", LATEST_PROTOCOL_VERSION] diff --git a/src/mcp/types.py b/src/mcp/types.py index 8713227404..654c00660b 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,5 +1,6 @@ from collections.abc import Callable -from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar +from datetime import datetime +from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints @@ -23,7 +24,7 @@ not separate types in the schema. """ -LATEST_PROTOCOL_VERSION = "2025-06-18" +LATEST_PROTOCOL_VERSION = "2025-11-25" """ The default negotiated version of the Model Context Protocol when no version is specified. @@ -39,6 +40,23 @@ RequestId = Annotated[int, Field(strict=True)] | str AnyFunction: TypeAlias = Callable[..., Any] +TaskExecutionMode = Literal["forbidden", "optional", "required"] +TASK_FORBIDDEN: Final[Literal["forbidden"]] = "forbidden" +TASK_OPTIONAL: Final[Literal["optional"]] = "optional" +TASK_REQUIRED: Final[Literal["required"]] = "required" + + +class TaskMetadata(BaseModel): + """ + Metadata for augmenting a request with task execution. + Include this in the `task` field of the request parameters. + """ + + model_config = ConfigDict(extra="allow") + + ttl: Annotated[int, Field(strict=True)] | None = None + """Requested duration in milliseconds to retain task from creation.""" + class RequestParams(BaseModel): class Meta(BaseModel): @@ -52,6 +70,16 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") + task: TaskMetadata | None = None + """ + If specified, the caller is requesting task-augmented execution for this request. + The request will return a CreateTaskResult immediately, and the actual result can be + retrieved later via tasks/result. + + Task augmentation is subject to capability negotiation - receivers MUST declare support + for task augmentation of specific request types in their capabilities. + """ + meta: Meta | None = Field(alias="_meta", default=None) @@ -146,6 +174,10 @@ class JSONRPCResponse(BaseModel): model_config = ConfigDict(extra="allow") +# MCP-specific error codes in the range [-32000, -32099] +URL_ELICITATION_REQUIRED = -32042 +"""Error code indicating that a URL mode elicitation is required before the request can be processed.""" + # SDK error codes CONNECTION_CLOSED = -32000 # REQUEST_TIMEOUT = -32001 # the typescript sdk uses this @@ -250,17 +282,137 @@ class RootsCapability(BaseModel): model_config = ConfigDict(extra="allow") -class SamplingCapability(BaseModel): - """Capability for sampling operations.""" +class SamplingContextCapability(BaseModel): + """ + Capability for context inclusion during sampling. + + Indicates support for non-'none' values in the includeContext parameter. + SOFT-DEPRECATED: New implementations should use tools parameter instead. + """ + + model_config = ConfigDict(extra="allow") + + +class SamplingToolsCapability(BaseModel): + """ + Capability indicating support for tool calling during sampling. + + When present in ClientCapabilities.sampling, indicates that the client + supports the tools and toolChoice parameters in sampling requests. + """ + + model_config = ConfigDict(extra="allow") + + +class FormElicitationCapability(BaseModel): + """Capability for form mode elicitation.""" + + model_config = ConfigDict(extra="allow") + + +class UrlElicitationCapability(BaseModel): + """Capability for URL mode elicitation.""" model_config = ConfigDict(extra="allow") class ElicitationCapability(BaseModel): - """Capability for elicitation operations.""" + """Capability for elicitation operations. + + Clients must support at least one mode (form or url). + """ + + form: FormElicitationCapability | None = None + """Present if the client supports form mode elicitation.""" + + url: UrlElicitationCapability | None = None + """Present if the client supports URL mode elicitation.""" + + model_config = ConfigDict(extra="allow") + + +class SamplingCapability(BaseModel): + """ + Sampling capability structure, allowing fine-grained capability advertisement. + """ + + context: SamplingContextCapability | None = None + """ + Present if the client supports non-'none' values for includeContext parameter. + SOFT-DEPRECATED: New implementations should use tools parameter instead. + """ + tools: SamplingToolsCapability | None = None + """ + Present if the client supports tools and toolChoice parameters in sampling requests. + Presence indicates full tool calling support during sampling. + """ + model_config = ConfigDict(extra="allow") + + +class TasksListCapability(BaseModel): + """Capability for tasks listing operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksCancelCapability(BaseModel): + """Capability for tasks cancel operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksCreateMessageCapability(BaseModel): + """Capability for tasks create messages.""" + + model_config = ConfigDict(extra="allow") + + +class TasksSamplingCapability(BaseModel): + """Capability for tasks sampling operations.""" + + model_config = ConfigDict(extra="allow") + + createMessage: TasksCreateMessageCapability | None = None + + +class TasksCreateElicitationCapability(BaseModel): + """Capability for tasks create elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksElicitationCapability(BaseModel): + """Capability for tasks elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + create: TasksCreateElicitationCapability | None = None + + +class ClientTasksRequestsCapability(BaseModel): + """Capability for tasks requests operations.""" + + model_config = ConfigDict(extra="allow") + + sampling: TasksSamplingCapability | None = None + + elicitation: TasksElicitationCapability | None = None + + +class ClientTasksCapability(BaseModel): + """Capability for client tasks operations.""" model_config = ConfigDict(extra="allow") + list: TasksListCapability | None = None + """Whether this client supports tasks/list.""" + + cancel: TasksCancelCapability | None = None + """Whether this client supports tasks/cancel.""" + + requests: ClientTasksRequestsCapability | None = None + """Specifies which request types can be augmented with tasks.""" + class ClientCapabilities(BaseModel): """Capabilities a client may support.""" @@ -268,11 +420,17 @@ class ClientCapabilities(BaseModel): experimental: dict[str, dict[str, Any]] | None = None """Experimental, non-standard capabilities that the client supports.""" sampling: SamplingCapability | None = None - """Present if the client supports sampling from an LLM.""" + """ + Present if the client supports sampling from an LLM. + Can contain fine-grained capabilities like context and tools support. + """ elicitation: ElicitationCapability | None = None """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" + tasks: ClientTasksCapability | None = None + """Present if the client supports task-augmented requests.""" + model_config = ConfigDict(extra="allow") @@ -314,6 +472,37 @@ class CompletionsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class TasksCallCapability(BaseModel): + """Capability for tasks call operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksToolsCapability(BaseModel): + """Capability for tasks tools operations.""" + + model_config = ConfigDict(extra="allow") + call: TasksCallCapability | None = None + + +class ServerTasksRequestsCapability(BaseModel): + """Capability for tasks requests operations.""" + + model_config = ConfigDict(extra="allow") + + tools: TasksToolsCapability | None = None + + +class ServerTasksCapability(BaseModel): + """Capability for server tasks operations.""" + + model_config = ConfigDict(extra="allow") + + list: TasksListCapability | None = None + cancel: TasksCancelCapability | None = None + requests: ServerTasksRequestsCapability | None = None + + class ServerCapabilities(BaseModel): """Capabilities that a server may support.""" @@ -329,7 +518,154 @@ class ServerCapabilities(BaseModel): """Present if the server offers any tools to call.""" completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" + tasks: ServerTasksCapability | None = None + """Present if the server supports task-augmented requests.""" + model_config = ConfigDict(extra="allow") + + +TaskStatus = Literal["working", "input_required", "completed", "failed", "cancelled"] + +# Task status constants +TASK_STATUS_WORKING: Final[Literal["working"]] = "working" +TASK_STATUS_INPUT_REQUIRED: Final[Literal["input_required"]] = "input_required" +TASK_STATUS_COMPLETED: Final[Literal["completed"]] = "completed" +TASK_STATUS_FAILED: Final[Literal["failed"]] = "failed" +TASK_STATUS_CANCELLED: Final[Literal["cancelled"]] = "cancelled" + + +class RelatedTaskMetadata(BaseModel): + """ + Metadata for associating messages with a task. + + Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. + """ + model_config = ConfigDict(extra="allow") + taskId: str + """The task identifier this message is associated with.""" + + +class Task(BaseModel): + """Data associated with a task.""" + + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier.""" + + status: TaskStatus + """Current task state.""" + + statusMessage: str | None = None + """ + Optional human-readable message describing the current task state. + This can provide context for any status, including: + - Reasons for "cancelled" status + - Summaries for "completed" status + - Diagnostic information for "failed" status (e.g., error details, what went wrong) + """ + + createdAt: datetime # Pydantic will enforce ISO 8601 and re-serialize as a string later + """ISO 8601 timestamp when the task was created.""" + + lastUpdatedAt: datetime + """ISO 8601 timestamp when the task was last updated.""" + + ttl: Annotated[int, Field(strict=True)] | None + """Actual retention duration from creation in milliseconds, null for unlimited.""" + + pollInterval: Annotated[int, Field(strict=True)] | None = None + """Suggested polling interval in milliseconds.""" + + +class CreateTaskResult(Result): + """A response to a task-augmented request.""" + + task: Task + + +class GetTaskRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + taskId: str + """The task identifier to query.""" + + +class GetTaskRequest(Request[GetTaskRequestParams, Literal["tasks/get"]]): + """A request to retrieve the state of a task.""" + + method: Literal["tasks/get"] = "tasks/get" + + params: GetTaskRequestParams + + +class GetTaskResult(Result, Task): + """The response to a tasks/get request.""" + + +class GetTaskPayloadRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier to retrieve results for.""" + + +class GetTaskPayloadRequest(Request[GetTaskPayloadRequestParams, Literal["tasks/result"]]): + """A request to retrieve the result of a completed task.""" + + method: Literal["tasks/result"] = "tasks/result" + params: GetTaskPayloadRequestParams + + +class GetTaskPayloadResult(Result): + """ + The response to a tasks/result request. + The structure matches the result type of the original request. + For example, a tools/call task would return the CallToolResult structure. + """ + + +class CancelTaskRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier to cancel.""" + + +class CancelTaskRequest(Request[CancelTaskRequestParams, Literal["tasks/cancel"]]): + """A request to cancel a task.""" + + method: Literal["tasks/cancel"] = "tasks/cancel" + params: CancelTaskRequestParams + + +class CancelTaskResult(Result, Task): + """The response to a tasks/cancel request.""" + + +class ListTasksRequest(PaginatedRequest[Literal["tasks/list"]]): + """A request to retrieve a list of tasks.""" + + method: Literal["tasks/list"] = "tasks/list" + + +class ListTasksResult(PaginatedResult): + """The response to a tasks/list request.""" + + tasks: list[Task] + + +class TaskStatusNotificationParams(NotificationParams, Task): + """Parameters for a `notifications/tasks/status` notification.""" + + +class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): + """ + An optional notification from the receiver to the requestor, informing them that a task's status has changed. + Receivers are not required to send these notifications + """ + + method: Literal["notifications/tasks/status"] = "notifications/tasks/status" + params: TaskStatusNotificationParams class InitializeRequestParams(RequestParams): @@ -742,13 +1078,101 @@ class AudioContent(BaseModel): model_config = ConfigDict(extra="allow") +class ToolUseContent(BaseModel): + """ + Content representing an assistant's request to invoke a tool. + + This content type appears in assistant messages when the LLM wants to call a tool + during sampling. The server should execute the tool and return a ToolResultContent + in the next user message. + """ + + type: Literal["tool_use"] + """Discriminator for tool use content.""" + + name: str + """The name of the tool to invoke. Must match a tool name from the request's tools array.""" + + id: str + """Unique identifier for this tool call, used to correlate with ToolResultContent.""" + + input: dict[str, Any] + """Arguments to pass to the tool. Must conform to the tool's inputSchema.""" + + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class ToolResultContent(BaseModel): + """ + Content representing the result of a tool execution. + + This content type appears in user messages as a response to a ToolUseContent + from the assistant. It contains the output of executing the requested tool. + """ + + type: Literal["tool_result"] + """Discriminator for tool result content.""" + + toolUseId: str + """The unique identifier that corresponds to the tool call's id field.""" + + content: list["ContentBlock"] = [] + """ + A list of content objects representing the tool result. + Defaults to empty list if not provided. + """ + + structuredContent: dict[str, Any] | None = None + """ + Optional structured tool output that matches the tool's outputSchema (if defined). + """ + + isError: bool | None = None + """Whether the tool execution resulted in an error.""" + + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +SamplingMessageContentBlock: TypeAlias = TextContent | ImageContent | AudioContent | ToolUseContent | ToolResultContent +"""Content block types allowed in sampling messages.""" + +SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent +"""Basic content types for sampling responses (without tool use). +Used for backwards-compatible CreateMessageResult when tools are not used.""" + + class SamplingMessage(BaseModel): """Describes a message issued to or received from an LLM API.""" role: Role - content: TextContent | ImageContent | AudioContent + content: SamplingMessageContentBlock | list[SamplingMessageContentBlock] + """ + Message content. Can be a single content block or an array of content blocks + for multi-modal messages and tool interactions. + """ + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") + @property + def content_as_list(self) -> list[SamplingMessageContentBlock]: + """Returns the content as a list of content blocks, regardless of whether + it was originally a single block or a list.""" + return self.content if isinstance(self.content, list) else [self.content] + class EmbeddedResource(BaseModel): """ @@ -865,9 +1289,29 @@ class ToolAnnotations(BaseModel): of a memory tool is not. Default: true """ + model_config = ConfigDict(extra="allow") +class ToolExecution(BaseModel): + """Execution-related properties for a tool.""" + + model_config = ConfigDict(extra="allow") + + taskSupport: TaskExecutionMode | None = None + """ + Indicates whether this tool supports task-augmented execution. + This allows clients to handle long-running operations through polling + the task system. + + - "forbidden": Tool does not support task-augmented execution (default when absent) + - "optional": Tool may support task-augmented execution + - "required": Tool requires task-augmented execution + + Default: "forbidden" + """ + + class Tool(BaseMetadata): """Definition for a tool the client can call.""" @@ -889,6 +1333,9 @@ class Tool(BaseMetadata): See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) for notes on _meta usage. """ + + execution: ToolExecution | None = None + model_config = ConfigDict(extra="allow") @@ -1035,6 +1482,25 @@ class ModelPreferences(BaseModel): model_config = ConfigDict(extra="allow") +class ToolChoice(BaseModel): + """ + Controls tool usage behavior during sampling. + + Allows the server to specify whether and how the LLM should use tools + in its response. + """ + + mode: Literal["auto", "required", "none"] | None = None + """ + Controls when tools are used: + - "auto": Model decides whether to use tools (default) + - "required": Model MUST use at least one tool before completing + - "none": Model should not use tools + """ + + model_config = ConfigDict(extra="allow") + + class CreateMessageRequestParams(RequestParams): """Parameters for creating a message.""" @@ -1057,6 +1523,16 @@ class CreateMessageRequestParams(RequestParams): stopSequences: list[str] | None = None metadata: dict[str, Any] | None = None """Optional metadata to pass through to the LLM provider.""" + tools: list["Tool"] | None = None + """ + Tool definitions for the LLM to use during sampling. + Requires clientCapabilities.sampling.tools to be present. + """ + toolChoice: ToolChoice | None = None + """ + Controls tool usage behavior. + Requires clientCapabilities.sampling.tools and the tools parameter to be present. + """ model_config = ConfigDict(extra="allow") @@ -1067,20 +1543,54 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling params: CreateMessageRequestParams -StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str +StopReason = Literal["endTurn", "stopSequence", "maxTokens", "toolUse"] | str class CreateMessageResult(Result): - """The client's response to a sampling/create_message request from the server.""" + """The client's response to a sampling/create_message request from the server. + + This is the backwards-compatible version that returns single content (no arrays). + Used when the request does not include tools. + """ role: Role - content: TextContent | ImageContent | AudioContent + """The role of the message sender (typically 'assistant' for LLM responses).""" + content: SamplingContent + """Response content. Single content block (text, image, or audio).""" model: str """The name of the model that generated the message.""" stopReason: StopReason | None = None """The reason why sampling stopped, if known.""" +class CreateMessageResultWithTools(Result): + """The client's response to a sampling/create_message request when tools were provided. + + This version supports array content for tool use flows. + """ + + role: Role + """The role of the message sender (typically 'assistant' for LLM responses).""" + content: SamplingMessageContentBlock | list[SamplingMessageContentBlock] + """ + Response content. May be a single content block or an array. + May include ToolUseContent if stopReason is 'toolUse'. + """ + model: str + """The name of the model that generated the message.""" + stopReason: StopReason | None = None + """ + The reason why sampling stopped, if known. + 'toolUse' indicates the model wants to use a tool. + """ + + @property + def content_as_list(self) -> list[SamplingMessageContentBlock]: + """Returns the content as a list of content blocks, regardless of whether + it was originally a single block or a list.""" + return self.content if isinstance(self.content, list) else [self.content] + + class ResourceTemplateReference(BaseModel): """A reference to a resource or resource template definition.""" @@ -1230,10 +1740,17 @@ class RootsListChangedNotification( class CancelledNotificationParams(NotificationParams): """Parameters for cancellation notifications.""" - requestId: RequestId - """The ID of the request to cancel.""" + requestId: RequestId | None = None + """ + The ID of the request to cancel. + + This MUST correspond to the ID of a request previously issued in the same direction. + This MUST be provided for cancelling non-task requests. + This MUST NOT be used for cancelling tasks (use the `tasks/cancel` request instead). + """ reason: str | None = None """An optional string describing the reason for the cancellation.""" + model_config = ConfigDict(extra="allow") @@ -1247,29 +1764,67 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n params: CancelledNotificationParams -class ClientRequest( - RootModel[ - PingRequest - | InitializeRequest - | CompleteRequest - | SetLevelRequest - | GetPromptRequest - | ListPromptsRequest - | ListResourcesRequest - | ListResourceTemplatesRequest - | ReadResourceRequest - | SubscribeRequest - | UnsubscribeRequest - | CallToolRequest - | ListToolsRequest - ] +class ElicitCompleteNotificationParams(NotificationParams): + """Parameters for elicitation completion notifications.""" + + elicitationId: str + """The unique identifier of the elicitation that was completed.""" + + model_config = ConfigDict(extra="allow") + + +class ElicitCompleteNotification( + Notification[ElicitCompleteNotificationParams, Literal["notifications/elicitation/complete"]] ): + """ + A notification from the server to the client, informing it that a URL mode + elicitation has been completed. + + Clients MAY use the notification to automatically retry requests that received a + URLElicitationRequiredError, update the user interface, or otherwise continue + an interaction. However, because delivery of the notification is not guaranteed, + clients must not wait indefinitely for a notification from the server. + """ + + method: Literal["notifications/elicitation/complete"] = "notifications/elicitation/complete" + params: ElicitCompleteNotificationParams + + +ClientRequestType: TypeAlias = ( + PingRequest + | InitializeRequest + | CompleteRequest + | SetLevelRequest + | GetPromptRequest + | ListPromptsRequest + | ListResourcesRequest + | ListResourceTemplatesRequest + | ReadResourceRequest + | SubscribeRequest + | UnsubscribeRequest + | CallToolRequest + | ListToolsRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | CancelTaskRequest +) + + +class ClientRequest(RootModel[ClientRequestType]): pass -class ClientNotification( - RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] -): +ClientNotificationType: TypeAlias = ( + CancelledNotification + | ProgressNotification + | InitializedNotification + | RootsListChangedNotification + | TaskStatusNotification +) + + +class ClientNotification(RootModel[ClientNotificationType]): pass @@ -1278,14 +1833,58 @@ class ClientNotification( """Schema for elicitation requests.""" -class ElicitRequestParams(RequestParams): - """Parameters for elicitation requests.""" +class ElicitRequestFormParams(RequestParams): + """Parameters for form mode elicitation requests. + + Form mode collects non-sensitive information from the user via an in-band form + rendered by the client. + """ + + mode: Literal["form"] = "form" + """The elicitation mode (always "form" for this type).""" message: str + """The message to present to the user describing what information is being requested.""" + requestedSchema: ElicitRequestedSchema + """ + A restricted subset of JSON Schema defining the structure of expected response. + Only top-level properties are allowed, without nesting. + """ + model_config = ConfigDict(extra="allow") +class ElicitRequestURLParams(RequestParams): + """Parameters for URL mode elicitation requests. + + URL mode directs users to external URLs for sensitive out-of-band interactions + like OAuth flows, credential collection, or payment processing. + """ + + mode: Literal["url"] = "url" + """The elicitation mode (always "url" for this type).""" + + message: str + """The message to present to the user explaining why the interaction is needed.""" + + url: str + """The URL that the user should navigate to.""" + + elicitationId: str + """ + The ID of the elicitation, which must be unique within the context of the server. + The client MUST treat this ID as an opaque value. + """ + + model_config = ConfigDict(extra="allow") + + +# Union type for elicitation request parameters +ElicitRequestParams: TypeAlias = ElicitRequestURLParams | ElicitRequestFormParams +"""Parameters for elicitation requests - either form or URL mode.""" + + class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): """A request from the server to elicit information from the client.""" @@ -1299,52 +1898,102 @@ class ElicitResult(Result): action: Literal["accept", "decline", "cancel"] """ The user action in response to the elicitation. - - "accept": User submitted the form/confirmed the action + - "accept": User submitted the form/confirmed the action (or consented to URL navigation) - "decline": User explicitly declined the action - "cancel": User dismissed without making an explicit choice """ - content: dict[str, str | int | float | bool | None] | None = None + content: dict[str, str | int | float | bool | list[str] | None] | None = None + """ + The submitted form data, only present when action is "accept" in form mode. + Contains values matching the requested schema. Values can be strings, integers, + booleans, or arrays of strings. + For URL mode, this field is omitted. """ - The submitted form data, only present when action is "accept". - Contains values matching the requested schema. + + +class ElicitationRequiredErrorData(BaseModel): + """Error data for URLElicitationRequiredError. + + Servers return this when a request cannot be processed until one or more + URL mode elicitations are completed. """ + elicitations: list[ElicitRequestURLParams] + """List of URL mode elicitations that must be completed.""" + + model_config = ConfigDict(extra="allow") -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): + +ClientResultType: TypeAlias = ( + EmptyResult + | CreateMessageResult + | CreateMessageResultWithTools + | ListRootsResult + | ElicitResult + | GetTaskResult + | GetTaskPayloadResult + | ListTasksResult + | CancelTaskResult + | CreateTaskResult +) + + +class ClientResult(RootModel[ClientResultType]): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): +ServerRequestType: TypeAlias = ( + PingRequest + | CreateMessageRequest + | ListRootsRequest + | ElicitRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | CancelTaskRequest +) + + +class ServerRequest(RootModel[ServerRequestType]): pass -class ServerNotification( - RootModel[ - CancelledNotification - | ProgressNotification - | LoggingMessageNotification - | ResourceUpdatedNotification - | ResourceListChangedNotification - | ToolListChangedNotification - | PromptListChangedNotification - ] -): +ServerNotificationType: TypeAlias = ( + CancelledNotification + | ProgressNotification + | LoggingMessageNotification + | ResourceUpdatedNotification + | ResourceListChangedNotification + | ToolListChangedNotification + | PromptListChangedNotification + | ElicitCompleteNotification + | TaskStatusNotification +) + + +class ServerNotification(RootModel[ServerNotificationType]): pass -class ServerResult( - RootModel[ - EmptyResult - | InitializeResult - | CompleteResult - | GetPromptResult - | ListPromptsResult - | ListResourcesResult - | ListResourceTemplatesResult - | ReadResourceResult - | CallToolResult - | ListToolsResult - ] -): +ServerResultType: TypeAlias = ( + EmptyResult + | InitializeResult + | CompleteResult + | GetPromptResult + | ListPromptsResult + | ListResourcesResult + | ListResourceTemplatesResult + | ReadResourceResult + | CallToolResult + | ListToolsResult + | GetTaskResult + | GetTaskPayloadResult + | ListTasksResult + | CancelTaskResult + | CreateTaskResult +) + + +class ServerResult(RootModel[ServerResultType]): pass diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index fb354ba7ff..44f4ab4d31 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -82,7 +82,7 @@ def test_get_npx_windows(monkeypatch: pytest.MonkeyPatch): def fake_run(cmd: list[str], **kw: Any) -> subprocess.CompletedProcess[bytes]: if cmd[0] in candidates: return subprocess.CompletedProcess(cmd, 0) - else: + else: # pragma: no cover raise subprocess.CalledProcessError(1, cmd[0]) monkeypatch.setattr(sys, "platform", "win32") diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py new file mode 100644 index 0000000000..6d134af742 --- /dev/null +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -0,0 +1,431 @@ +import urllib.parse + +import jwt +import pytest +from pydantic import AnyHttpUrl, AnyUrl + +from mcp.client.auth.extensions.client_credentials import ( + ClientCredentialsOAuthProvider, + JWTParameters, + PrivateKeyJWTOAuthProvider, + RFC7523OAuthClientProvider, + SignedJWTParameters, + static_assertion_provider, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken + + +class MockTokenStorage: + """Mock token storage for testing.""" + + def __init__(self): + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: # pragma: no cover + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: # pragma: no cover + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: # pragma: no cover + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: # pragma: no cover + self._client_info = client_info + + +@pytest.fixture +def mock_storage(): + return MockTokenStorage() + + +@pytest.fixture +def client_metadata(): + return OAuthClientMetadata( + client_name="Test Client", + client_uri=AnyHttpUrl("https://example.com/"), + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + scope="read write", + ) + + +@pytest.fixture +def rfc7523_oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): + async def redirect_handler(url: str) -> None: # pragma: no cover + """Mock redirect handler.""" + pass + + async def callback_handler() -> tuple[str, str | None]: # pragma: no cover + """Mock callback handler.""" + return "test_auth_code", "test_state" + + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return RFC7523OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + +class TestOAuthFlowClientCredentials: + """Test OAuth flow behavior for client credentials flows.""" + + @pytest.mark.anyio + async def test_token_exchange_request_jwt_predefined(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): + """Test token exchange request building with a predefined JWT assertion.""" + # Set up required context + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + token_endpoint_auth_method="private_key_jwt", + redirect_uris=None, + scope="read write", + ) + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com/"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + registration_endpoint=AnyHttpUrl("https://api.example.com/register"), + ) + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" + rfc7523_oauth_provider.jwt_parameters = JWTParameters( + # https://www.jwt.io + assertion="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30" + ) + + request = await rfc7523_oauth_provider._exchange_token_jwt_bearer() + + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + + # Check form data + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content + assert "scope=read write" in content + assert "resource=https://api.example.com/v1/mcp" in content + assert ( + "assertion=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30" + in content + ) + + @pytest.mark.anyio + async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523OAuthClientProvider): + """Test token exchange request building wiith a generated JWT assertion.""" + # Set up required context + rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull( + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + token_endpoint_auth_method="private_key_jwt", + redirect_uris=None, + scope="read write", + ) + rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com/"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + registration_endpoint=AnyHttpUrl("https://api.example.com/register"), + ) + rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info + rfc7523_oauth_provider.context.protocol_version = "2025-06-18" + rfc7523_oauth_provider.jwt_parameters = JWTParameters( + issuer="foo", + subject="1234567890", + claims={ + "name": "John Doe", + "admin": True, + "iat": 1516239022, + }, + jwt_signing_algorithm="HS256", + jwt_signing_key="a-string-secret-at-least-256-bits-long", + jwt_lifetime_seconds=300, + ) + + request = await rfc7523_oauth_provider._exchange_token_jwt_bearer() + + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + + # Check form data + content = urllib.parse.unquote_plus(request.content.decode()).split("&") + assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content + assert "scope=read write" in content + assert "resource=https://api.example.com/v1/mcp" in content + + # Check assertion + assertion = next(param for param in content if param.startswith("assertion="))[len("assertion=") :] + claims = jwt.decode( + assertion, + key="a-string-secret-at-least-256-bits-long", + algorithms=["HS256"], + audience="https://api.example.com/", + subject="1234567890", + issuer="foo", + verify=True, + ) + assert claims["name"] == "John Doe" + assert claims["admin"] + assert claims["iat"] == 1516239022 + + +class TestClientCredentialsOAuthProvider: + """Test ClientCredentialsOAuthProvider.""" + + @pytest.mark.anyio + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + """Test that _initialize sets client_info.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + ) + + # client_info is set during _initialize + await provider._initialize() + + assert provider.context.client_info is not None + assert provider.context.client_info.client_id == "test-client-id" + assert provider.context.client_info.client_secret == "test-client-secret" + assert provider.context.client_info.grant_types == ["client_credentials"] + assert provider.context.client_info.token_endpoint_auth_method == "client_secret_basic" + + @pytest.mark.anyio + async def test_init_with_scopes(self, mock_storage: MockTokenStorage): + """Test that constructor accepts scopes.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + scopes="read write", + ) + + await provider._initialize() + assert provider.context.client_info is not None + assert provider.context.client_info.scope == "read write" + + @pytest.mark.anyio + async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): + """Test that constructor accepts client_secret_post auth method.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + ) + + await provider._initialize() + assert provider.context.client_info is not None + assert provider.context.client_info.token_endpoint_auth_method == "client_secret_post" + + @pytest.mark.anyio + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): + """Test token exchange request building.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + scopes="read write", + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com/"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + + request = await provider._perform_authorization() + + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=read write" in content + assert "resource=https://api.example.com/v1/mcp" in content + + @pytest.mark.anyio + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): + """Test token exchange without scopes.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com/"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2024-11-05" # Old version - no resource param + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=" not in content + assert "resource=" not in content + + +class TestPrivateKeyJWTOAuthProvider: + """Test PrivateKeyJWTOAuthProvider.""" + + @pytest.mark.anyio + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + """Test that _initialize sets client_info.""" + + async def mock_assertion_provider(audience: str) -> str: # pragma: no cover + return "mock-jwt" + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/", + storage=mock_storage, + client_id="test-client-id", + assertion_provider=mock_assertion_provider, + ) + + # client_info is set during _initialize + await provider._initialize() + + assert provider.context.client_info is not None + assert provider.context.client_info.client_id == "test-client-id" + assert provider.context.client_info.grant_types == ["client_credentials"] + assert provider.context.client_info.token_endpoint_auth_method == "private_key_jwt" + + @pytest.mark.anyio + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): + """Test token exchange request building with assertion provider.""" + + async def mock_assertion_provider(audience: str) -> str: + return f"jwt-for-{audience}" + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + assertion_provider=mock_assertion_provider, + scopes="read write", + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + + request = await provider._perform_authorization() + + assert request.method == "POST" + assert str(request.url) == "https://auth.example.com/token" + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "client_assertion=jwt-for-https://auth.example.com/" in content + assert "client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer" in content + assert "scope=read write" in content + + @pytest.mark.anyio + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): + """Test token exchange without scopes.""" + + async def mock_assertion_provider(audience: str) -> str: + return f"jwt-for-{audience}" + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + assertion_provider=mock_assertion_provider, + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + provider.context.protocol_version = "2024-11-05" # Old version - no resource param + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=" not in content + assert "resource=" not in content + + +class TestSignedJWTParameters: + """Test SignedJWTParameters.""" + + @pytest.mark.anyio + async def test_create_assertion_provider(self): + """Test that create_assertion_provider creates valid JWTs.""" + params = SignedJWTParameters( + issuer="test-issuer", + subject="test-subject", + signing_key="a-string-secret-at-least-256-bits-long", + signing_algorithm="HS256", + lifetime_seconds=300, + ) + + provider = params.create_assertion_provider() + assertion = await provider("https://auth.example.com/") + + claims = jwt.decode( + assertion, + key="a-string-secret-at-least-256-bits-long", + algorithms=["HS256"], + audience="https://auth.example.com/", + ) + assert claims["iss"] == "test-issuer" + assert claims["sub"] == "test-subject" + assert claims["aud"] == "https://auth.example.com/" + assert "exp" in claims + assert "iat" in claims + assert "jti" in claims + + @pytest.mark.anyio + async def test_create_assertion_provider_with_additional_claims(self): + """Test that additional_claims are included in the JWT.""" + params = SignedJWTParameters( + issuer="test-issuer", + subject="test-subject", + signing_key="a-string-secret-at-least-256-bits-long", + signing_algorithm="HS256", + additional_claims={"custom": "value"}, + ) + + provider = params.create_assertion_provider() + assertion = await provider("https://auth.example.com/") + + claims = jwt.decode( + assertion, + key="a-string-secret-at-least-256-bits-long", + algorithms=["HS256"], + audience="https://auth.example.com/", + ) + assert claims["custom"] == "value" + + +class TestStaticAssertionProvider: + """Test static_assertion_provider helper.""" + + @pytest.mark.anyio + async def test_returns_static_token(self): + """Test that static_assertion_provider returns the same token regardless of audience.""" + token = "my-static-jwt-token" + provider = static_assertion_provider(token) + + result1 = await provider("https://auth1.example.com/") + result2 = await provider("https://auth2.example.com/") + + assert result1 == token + assert result2 == token diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 97014af9f0..1e5c4d524c 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -40,7 +40,7 @@ def clear(self) -> None: self.client.sent_messages.clear() self.server.sent_messages.clear() - def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]: + def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]: # pragma: no cover """Get client-sent requests, optionally filtered by method.""" return [ req.message.root @@ -48,15 +48,15 @@ def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest] if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) ] - def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: + def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: # pragma: no cover """Get server-sent requests, optionally filtered by method.""" - return [ + return [ # pragma: no cover req.message.root for req in self.server.sent_messages if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) ] - def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: + def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover """Get client-sent notifications, optionally filtered by method.""" return [ notif.message.root @@ -65,7 +65,7 @@ def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNot and (method is None or notif.message.root.method == method) ] - def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: + def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover """Get server-sent notifications, optionally filtered by method.""" return [ notif.message.root diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 6e58e496d3..6025ff811b 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2,8 +2,10 @@ Tests for refactored OAuth client authentication implementation. """ +import base64 import time from unittest import mock +from urllib.parse import unquote import httpx import pytest @@ -11,7 +13,27 @@ from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthClientProvider, PKCEParameters -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata +from mcp.client.auth.utils import ( + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, + create_client_info_from_metadata_url, + create_client_registration_request, + create_oauth_metadata_request, + extract_field_from_www_auth, + extract_resource_metadata_from_www_auth, + extract_scope_from_www_auth, + get_client_metadata_scopes, + handle_registration_response, + is_valid_client_metadata_url, + should_use_client_metadata_url, +) +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) class MockTokenStorage: @@ -22,13 +44,13 @@ def __init__(self): self._client_info: OAuthClientInformationFull | None = None async def get_tokens(self) -> OAuthToken | None: - return self._tokens + return self._tokens # pragma: no cover async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: - return self._client_info + return self._client_info # pragma: no cover async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info @@ -64,11 +86,11 @@ def valid_tokens(): def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: """Mock redirect handler.""" - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: """Mock callback handler.""" - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover return OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -79,6 +101,52 @@ async def callback_handler() -> tuple[str, str | None]: ) +@pytest.fixture +def prm_metadata_response(): + """PRM metadata response with scopes.""" + return httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", ' + b'"authorization_servers": ["https://auth.example.com/"], ' + b'"scopes_supported": ["resource:read", "resource:write"]}' + ), + ) + + +@pytest.fixture +def prm_metadata_without_scopes_response(): + """PRM metadata response without scopes.""" + return httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", ' + b'"authorization_servers": ["https://auth.example.com/"], ' + b'"scopes_supported": null}' + ), + ) + + +@pytest.fixture +def init_response_with_www_auth_scope(): + """Initial 401 response with WWW-Authenticate header containing scope.""" + return httpx.Response( + 401, + headers={"WWW-Authenticate": 'Bearer scope="special:scope from:www-authenticate"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + +@pytest.fixture +def init_response_without_www_auth_scope(): + """Initial 401 response without WWW-Authenticate scope.""" + return httpx.Response( + 401, + headers={}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + class TestPKCEParameters: """Test PKCE parameter generation.""" @@ -195,16 +263,16 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_discover_protected_resource_request( + async def test_build_protected_resource_discovery_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test protected resource discovery request building maintains backward compatibility.""" + """Test protected resource metadata discovery URL building with fallback.""" async def redirect_handler(url: str) -> None: - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover provider = OAuthClientProvider( server_url="https://api.example.com/", @@ -219,25 +287,28 @@ async def callback_handler() -> tuple[str, str | None]: status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com/") ) - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" - assert "mcp-protocol-version" in request.headers + urls = build_protected_resource_metadata_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) + assert len(urls) == 1 + assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource" # Test with WWW-Authenticate header init_response.headers["WWW-Authenticate"] = ( 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' ) - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path" - assert "mcp-protocol-version" in request.headers + urls = build_protected_resource_metadata_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) + assert len(urls) == 2 + assert urls[0] == "https://prm.example.com/.well-known/oauth-protected-resource/path" + assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" @pytest.mark.anyio def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): """Test OAuth metadata discovery request building.""" - request = oauth_provider._create_oauth_metadata_request("https://example.com/") + request = create_oauth_metadata_request("https://example.com/") # Ensure correct method and headers, and that the URL is unmodified assert request.method == "GET" @@ -248,14 +319,69 @@ def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider class TestOAuthFallback: """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers.""" + @pytest.mark.anyio + async def test_oauth_discovery_legacy_fallback_when_no_prm(self): + """Test that when PRM discovery fails, only root OAuth URL is tried (March 2025 spec).""" + # When auth_server_url is None (PRM failed), we use server_url and only try root + discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://mcp.linear.app/sse") + + # Should only try the root URL (legacy behavior) + assert discovery_urls == [ + "https://mcp.linear.app/.well-known/oauth-authorization-server", + ] + + @pytest.mark.anyio + async def test_oauth_discovery_path_aware_when_auth_server_has_path(self): + """Test that when auth server URL has a path, only path-based URLs are tried.""" + discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( + "https://auth.example.com/tenant1", "https://api.example.com/mcp" + ) + + # Should try path-based URLs only (no root URLs) + assert discovery_urls == [ + "https://auth.example.com/.well-known/oauth-authorization-server/tenant1", + "https://auth.example.com/.well-known/openid-configuration/tenant1", + "https://auth.example.com/tenant1/.well-known/openid-configuration", + ] + + @pytest.mark.anyio + async def test_oauth_discovery_root_when_auth_server_has_no_path(self): + """Test that when auth server URL has no path, only root URLs are tried.""" + discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( + "https://auth.example.com/", "https://api.example.com/mcp" + ) + + # Should try root URLs only + assert discovery_urls == [ + "https://auth.example.com/.well-known/oauth-authorization-server", + "https://auth.example.com/.well-known/openid-configuration", + ] + + @pytest.mark.anyio + async def test_oauth_discovery_root_when_auth_server_has_only_slash(self): + """Test that when auth server URL has only trailing slash, treated as root.""" + discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( + "https://auth.example.com/", "https://api.example.com/mcp" + ) + + # Should try root URLs only + assert discovery_urls == [ + "https://auth.example.com/.well-known/oauth-authorization-server", + "https://auth.example.com/.well-known/openid-configuration", + ] + @pytest.mark.anyio async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider): - """Test fallback URL construction order.""" - discovery_urls = oauth_provider._get_discovery_urls() + """Test fallback URL construction order when auth server URL has a path.""" + # Simulate PRM discovery returning an auth server URL with a path + oauth_provider.context.auth_server_url = oauth_provider.context.server_url + + discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( + oauth_provider.context.auth_server_url, oauth_provider.context.server_url + ) assert discovery_urls == [ "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp", - "https://api.example.com/.well-known/oauth-authorization-server", "https://api.example.com/.well-known/openid-configuration/v1/mcp", "https://api.example.com/v1/mcp/.well-known/openid-configuration", ] @@ -299,13 +425,14 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl assert discovery_request.method == "GET" # Send a successful discovery response with minimal protected resource metadata + # Note: auth server URL has a path (/v1/mcp), so only path-based URLs will be tried discovery_response = httpx.Response( 200, content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/v1/mcp"]}', request=discovery_request, ) - # Next request should be to discover OAuth metadata + # Next request should be to discover OAuth metadata at path-aware OAuth URL oauth_metadata_request_1 = await auth_flow.asend(discovery_response) assert ( str(oauth_metadata_request_1.url) @@ -320,9 +447,9 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl request=oauth_metadata_request_1, ) - # Next request should be to discover OAuth metadata at the next endpoint + # Next request should be path-aware OIDC URL (not root URL since auth server has path) oauth_metadata_request_2 = await auth_flow.asend(oauth_metadata_response_1) - assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/oauth-authorization-server" + assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp" assert oauth_metadata_request_2.method == "GET" # Send a 400 response @@ -332,9 +459,9 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl request=oauth_metadata_request_2, ) - # Next request should be to discover OAuth metadata at the next endpoint + # Next request should be OIDC path-appended URL oauth_metadata_request_3 = await auth_flow.asend(oauth_metadata_response_2) - assert str(oauth_metadata_request_3.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp" + assert str(oauth_metadata_request_3.url) == "https://auth.example.com/v1/mcp/.well-known/openid-configuration" assert oauth_metadata_request_3.method == "GET" # Send a 500 response @@ -345,9 +472,12 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl ) # Mock the authorization process to minimize unnecessary state in this test - oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + oauth_provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) - # Next request should fall back to legacy behavior and auth with the RS (mocked /authorize, next is /token) + # All path-based URLs failed, flow continues with default endpoints + # Next request should be token exchange using MCP server base URL (fallback when OAuth metadata not found) token_request = await auth_flow.asend(oauth_metadata_response_3) assert str(token_request.url) == "https://api.example.com/token" assert token_request.method == "POST" @@ -392,40 +522,76 @@ async def test_handle_metadata_response_success(self, oauth_provider: OAuthClien assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/" @pytest.mark.anyio - async def test_register_client_request(self, oauth_provider: OAuthClientProvider): - """Test client registration request building.""" - request = await oauth_provider._register_client() + async def test_prioritize_www_auth_scope_over_prm( + self, + oauth_provider: OAuthClientProvider, + prm_metadata_response: httpx.Response, + init_response_with_www_auth_scope: httpx.Response, + ): + """Test that WWW-Authenticate scope is prioritized over PRM scopes.""" + # First, process PRM metadata to set protected_resource_metadata with scopes + await oauth_provider._handle_protected_resource_response(prm_metadata_response) + + # Process the scope selection with WWW-Authenticate header + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_with_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) - assert request is not None - assert request.method == "POST" - assert str(request.url) == "https://api.example.com/register" - assert request.headers["Content-Type"] == "application/json" + # Verify that WWW-Authenticate scope is used (not PRM scopes) + assert scopes == "special:scope from:www-authenticate" @pytest.mark.anyio - async def test_register_client_skip_if_registered(self, oauth_provider: OAuthClientProvider): - """Test client registration is skipped if already registered.""" - # Set existing client info - client_info = OAuthClientInformationFull( - client_id="existing_client", - redirect_uris=[AnyUrl("http://localhost:3030/callback")], + async def test_prioritize_prm_scopes_when_no_www_auth_scope( + self, + oauth_provider: OAuthClientProvider, + prm_metadata_response: httpx.Response, + init_response_without_www_auth_scope: httpx.Response, + ): + """Test that PRM scopes are prioritized when WWW-Authenticate header has no scopes.""" + # Process the PRM metadata to set protected_resource_metadata with scopes + await oauth_provider._handle_protected_resource_response(prm_metadata_response) + + # Process the scope selection without WWW-Authenticate scope + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_without_www_auth_scope), + oauth_provider.context.protected_resource_metadata, ) - oauth_provider.context.client_info = client_info - # Should return None (skip registration) - request = await oauth_provider._register_client() - assert request is None + # Verify that PRM scopes are used + assert scopes == "resource:read resource:write" @pytest.mark.anyio - async def test_token_exchange_request(self, oauth_provider: OAuthClientProvider): + async def test_omit_scope_when_no_prm_scopes_or_www_auth( + self, + oauth_provider: OAuthClientProvider, + prm_metadata_without_scopes_response: httpx.Response, + init_response_without_www_auth_scope: httpx.Response, + ): + """Test that scope is omitted when PRM has no scopes and WWW-Authenticate doesn't specify scope.""" + # Process the PRM metadata without scopes + await oauth_provider._handle_protected_resource_response(prm_metadata_without_scopes_response) + + # Process the scope selection without WWW-Authenticate scope + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_without_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) + # Verify that scope is omitted + assert scopes is None + + @pytest.mark.anyio + async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider): """Test token exchange request building.""" # Set up required context oauth_provider.context.client_info = OAuthClientInformationFull( client_id="test_client", client_secret="test_secret", redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", ) - request = await oauth_provider._exchange_token("test_auth_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier") assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" @@ -448,6 +614,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, client_id="test_client", client_secret="test_secret", redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_post", ) request = await oauth_provider._refresh_token() @@ -463,6 +630,114 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, assert "client_id=test_client" in content assert "client_secret=test_secret" in content + @pytest.mark.anyio + async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider): + """Test token exchange with client_secret_basic authentication.""" + # Set up OAuth metadata to support basic auth + oauth_provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + token_endpoint_auth_methods_supported=["client_secret_basic", "client_secret_post"], + ) + + client_id_raw = "test@client" # Include special character to test URL encoding + client_secret_raw = "test:secret" # Include colon to test URL encoding + + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id=client_id_raw, + client_secret=client_secret_raw, + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_basic", + ) + + request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier") + + # Should use basic auth (registered method) + assert "Authorization" in request.headers + assert request.headers["Authorization"].startswith("Basic ") + + # Decode and verify credentials are properly URL-encoded + encoded_creds = request.headers["Authorization"][6:] # Remove "Basic " prefix + decoded = base64.b64decode(encoded_creds).decode() + client_id, client_secret = decoded.split(":", 1) + + # Check URL encoding was applied + assert client_id == "test%40client" # @ should be encoded as %40 + assert client_secret == "test%3Asecret" # : should be encoded as %3A + + # Verify decoded values match original + assert unquote(client_id) == client_id_raw + assert unquote(client_secret) == client_secret_raw + + # client_secret should NOT be in body for basic auth + content = request.content.decode() + assert "client_secret=" not in content + assert "client_id=test%40client" in content # client_id still in body + + @pytest.mark.anyio + async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + """Test token refresh with client_secret_basic authentication.""" + oauth_provider.context.current_tokens = valid_tokens + + # Set up OAuth metadata to only support basic auth + oauth_provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + token_endpoint_auth_methods_supported=["client_secret_basic"], + ) + + client_id = "test_client" + client_secret = "test_secret" + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id=client_id, + client_secret=client_secret, + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="client_secret_basic", + ) + + request = await oauth_provider._refresh_token() + + assert "Authorization" in request.headers + assert request.headers["Authorization"].startswith("Basic ") + + encoded_creds = request.headers["Authorization"][6:] + decoded = base64.b64decode(encoded_creds).decode() + assert decoded == f"{client_id}:{client_secret}" + + # client_secret should NOT be in body + content = request.content.decode() + assert "client_secret=" not in content + + @pytest.mark.anyio + async def test_none_auth_method(self, oauth_provider: OAuthClientProvider): + """Test 'none' authentication method (public client).""" + oauth_provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + token_endpoint_auth_methods_supported=["none"], + ) + + client_id = "public_client" + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id=client_id, + client_secret=None, # No secret for public client + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + token_endpoint_auth_method="none", + ) + + request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier") + + # Should NOT have Authorization header + assert "Authorization" not in request.headers + + # Should NOT have client_secret in body + content = request.content.decode() + assert "client_secret=" not in content + assert "client_id=public_client" in content + class TestProtectedResourceMetadata: """Test protected resource handling.""" @@ -479,7 +754,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ ) # Test in token exchange - request = await oauth_provider._exchange_token("test_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" in content # Check URL-encoded resource parameter @@ -510,7 +785,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro ) # Test in token exchange - request = await oauth_provider._exchange_token("test_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" not in content @@ -540,7 +815,7 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa ) # Test in token exchange - request = await oauth_provider._exchange_token("test_code", "test_verifier") + request = await oauth_provider._exchange_token_authorization_code("test_code", "test_verifier") content = request.content.decode() assert "resource=" in content @@ -549,7 +824,7 @@ class TestRegistrationResponse: """Test client registration response handling.""" @pytest.mark.anyio - async def test_handle_registration_response_reads_before_accessing_text(self, oauth_provider: OAuthClientProvider): + async def test_handle_registration_response_reads_before_accessing_text(self): """Test that response.aread() is called before accessing response.text.""" # Track if aread() was called @@ -566,14 +841,14 @@ async def aread(self): @property def text(self): if not self._aread_called: - raise RuntimeError("Response.text accessed before response.aread()") + raise RuntimeError("Response.text accessed before response.aread()") # pragma: no cover return self._text mock_response = MockResponse() # This should call aread() before accessing text with pytest.raises(Exception) as exc_info: - await oauth_provider._handle_registration_response(mock_response) + await handle_registration_response(mock_response) # Verify aread() was called assert mock_response._aread_called @@ -581,6 +856,49 @@ def text(self): assert "Registration failed: 400" in str(exc_info.value) +class TestCreateClientRegistrationRequest: + """Test client registration request creation.""" + + def test_uses_registration_endpoint_from_metadata(self): + """Test that registration URL comes from metadata when available.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), + ) + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(oauth_metadata, client_metadata, "https://auth.example.com/") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + def test_falls_back_to_default_register_endpoint_when_no_metadata(self): + """Test that registration uses fallback URL when auth_server_metadata is None.""" + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(None, client_metadata, "https://auth.example.com/") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + def test_falls_back_when_metadata_has_no_registration_endpoint(self): + """Test fallback when metadata exists but lacks registration_endpoint.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + # No registration_endpoint + ) + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(oauth_metadata, client_metadata, "https://auth.example.com/") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + class TestAuthFlow: """Test the auth flow in httpx.""" @@ -613,7 +931,7 @@ async def test_auth_flow_with_valid_tokens( pass # Expected @pytest.mark.anyio - async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvider): + async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage): """Test auth flow when no tokens are available, triggering the full OAuth flow.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -682,7 +1000,9 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide ) # Mock the authorization process - oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + oauth_provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) # Next request should be to exchange token token_request = await auth_flow.asend(registration_response) @@ -747,13 +1067,13 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( # In the fixed version, this should end the generator try: await auth_flow.asend(response) # extra request - request_yields += 1 + request_yields += 1 # pragma: no cover # If we reach here, the bug is present pytest.fail( f"Unnecessary retry detected! Request was yielded {request_yields} times. " f"This indicates the retry logic bug that caused 2x performance degradation. " f"The request should only be yielded once for successful responses." - ) + ) # pragma: no cover except StopAsyncIteration: # This is the expected behavior - no unnecessary retry pass @@ -761,131 +1081,642 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( # Verify exactly one request was yielded (no double-sending) assert request_yields == 1, f"Expected 1 request yield, got {request_yields}" + @pytest.mark.anyio + async def test_token_exchange_accepts_201_status( + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage + ): + """Test that token exchange accepts both 200 and 201 status codes.""" + # Ensure no tokens are stored + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True -@pytest.mark.parametrize( - ( - "issuer_url", - "service_documentation_url", - "authorization_endpoint", - "token_endpoint", - "registration_endpoint", - "revocation_endpoint", - ), - ( - # Pydantic's AnyUrl incorrectly adds trailing slash to base URLs - # This is being fixed in https://github.com/pydantic/pydantic-core/pull/1719 (Pydantic 2.12+) - pytest.param( - "https://auth.example.com/", - "https://auth.example.com/docs", - "https://auth.example.com/authorize", - "https://auth.example.com/token", - "https://auth.example.com/register", - "https://auth.example.com/revoke", - id="simple-url", - marks=pytest.mark.xfail( - reason="Pydantic AnyUrl adds trailing slash to base URLs - fixed in Pydantic 2.12+" - ), - ), - pytest.param( - "https://auth.example.com/", - "https://auth.example.com/docs", - "https://auth.example.com/authorize", - "https://auth.example.com/token", - "https://auth.example.com/register", - "https://auth.example.com/revoke", - id="with-trailing-slash", - ), - pytest.param( - "https://auth.example.com/v1/mcp", - "https://auth.example.com/v1/mcp/docs", - "https://auth.example.com/v1/mcp/authorize", - "https://auth.example.com/v1/mcp/token", - "https://auth.example.com/v1/mcp/register", - "https://auth.example.com/v1/mcp/revoke", - id="with-path-param", - ), - ), -) -def test_build_metadata( - issuer_url: str, - service_documentation_url: str, - authorization_endpoint: str, - token_endpoint: str, - registration_endpoint: str, - revocation_endpoint: str, -): - from mcp.server.auth.routes import build_metadata - from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions + # Create a test request + test_request = httpx.Request("GET", "https://api.example.com/mcp") - metadata = build_metadata( - issuer_url=AnyHttpUrl(issuer_url), - service_documentation_url=AnyHttpUrl(service_documentation_url), - client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]), - revocation_options=RevocationOptions(enabled=True), - ) + # Mock the auth flow + auth_flow = oauth_provider.async_auth_flow(test_request) - assert metadata.model_dump(exclude_defaults=True, mode="json") == snapshot( - { - "issuer": Is(issuer_url), - "authorization_endpoint": Is(authorization_endpoint), - "token_endpoint": Is(token_endpoint), - "registration_endpoint": Is(registration_endpoint), - "scopes_supported": ["read", "write", "admin"], - "grant_types_supported": ["authorization_code", "refresh_token"], - "token_endpoint_auth_methods_supported": ["client_secret_post"], - "service_documentation": Is(service_documentation_url), - "revocation_endpoint": Is(revocation_endpoint), - "revocation_endpoint_auth_methods_supported": ["client_secret_post"], - "code_challenge_methods_supported": ["S256"], - } - ) + # First request should be the original request without auth header + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send a 401 response to trigger the OAuth flow + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + # Next request should be to discover protected resource metadata + discovery_request = await auth_flow.asend(response) + assert discovery_request.method == "GET" + assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" -class TestProtectedResourceWWWAuthenticate: - """Test RFC9728 WWW-Authenticate header parsing functionality for protected resource.""" + # Send a successful discovery response with minimal protected resource metadata + discovery_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com/"]}', + request=discovery_request, + ) - @pytest.mark.parametrize( - "www_auth_header,expected_url", - [ - # Quoted URL - ( - 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', - "https://api.example.com/.well-known/oauth-protected-resource", + # Next request should be to discover OAuth metadata + oauth_metadata_request = await auth_flow.asend(discovery_response) + assert oauth_metadata_request.method == "GET" + assert str(oauth_metadata_request.url).startswith("https://auth.example.com/") + assert "mcp-protocol-version" in oauth_metadata_request.headers + + # Send a successful OAuth metadata response + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com/", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' ), - # Unquoted URL - ( - "Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource", - "https://api.example.com/.well-known/oauth-protected-resource", + request=oauth_metadata_request, + ) + + # Next request should be to register client + registration_request = await auth_flow.asend(oauth_metadata_response) + assert registration_request.method == "POST" + assert str(registration_request.url) == "https://auth.example.com/register" + + # Send a successful registration response with 201 status + registration_response = httpx.Response( + 201, + content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}', + request=registration_request, + ) + + # Mock the authorization process + oauth_provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + # Next request should be to exchange token + token_request = await auth_flow.asend(registration_response) + assert token_request.method == "POST" + assert str(token_request.url) == "https://auth.example.com/token" + assert "code=test_auth_code" in token_request.content.decode() + + # Send a successful token response with 201 status code (test both 200 and 201 are accepted) + token_response = httpx.Response( + 201, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, ' + b'"refresh_token": "new_refresh_token"}' + ), + request=token_request, + ) + + # Final request should be the original request with auth header + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer new_access_token" + assert final_request.method == "GET" + assert str(final_request.url) == "https://api.example.com/mcp" + + # Send final success response to properly close the generator + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass # Expected - generator should complete + + # Verify tokens were stored + assert oauth_provider.context.current_tokens is not None + assert oauth_provider.context.current_tokens.access_token == "new_access_token" + assert oauth_provider.context.token_expiry_time is not None + + @pytest.mark.anyio + async def test_403_insufficient_scope_updates_scope_from_header( + self, + oauth_provider: OAuthClientProvider, + mock_storage: MockTokenStorage, + valid_tokens: OAuthToken, + ): + """Test that 403 response correctly updates scope from WWW-Authenticate header.""" + # Pre-store valid tokens and client info + client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info(client_info) + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() + 1800 + oauth_provider.context.client_info = client_info + oauth_provider._initialized = True + + # Original scope + assert oauth_provider.context.client_metadata.scope == "read write" + + redirect_captured = False + captured_state = None + + async def capture_redirect(url: str) -> None: + nonlocal redirect_captured, captured_state + redirect_captured = True + # Verify the new scope is included in authorization URL + assert "scope=admin%3Awrite+admin%3Adelete" in url or "scope=admin:write+admin:delete" in url.replace( + "%3A", ":" + ).replace("+", " ") + # Extract state from redirect URL + from urllib.parse import parse_qs, urlparse + + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + oauth_provider.context.redirect_handler = capture_redirect + + # Mock callback + async def mock_callback() -> tuple[str, str | None]: + return "auth_code", captured_state + + oauth_provider.context.callback_handler = mock_callback + + test_request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + + # Send 403 with new scope requirement + response_403 = httpx.Response( + 403, + headers={"WWW-Authenticate": 'Bearer error="insufficient_scope", scope="admin:write admin:delete"'}, + request=request, + ) + + # Trigger step-up - should get token exchange request + token_exchange_request = await auth_flow.asend(response_403) + + # Verify scope was updated + assert oauth_provider.context.client_metadata.scope == "admin:write admin:delete" + assert redirect_captured + + # Complete the flow with successful token response + token_response = httpx.Response( + 200, + json={ + "access_token": "new_token_with_new_scope", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "admin:write admin:delete", + }, + request=token_exchange_request, + ) + + # Should get final retry request + final_request = await auth_flow.asend(token_response) + + # Send success response - flow should complete + success_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(success_response) + pytest.fail("Should have stopped after successful response") # pragma: no cover + except StopAsyncIteration: + pass # Expected + + +@pytest.mark.parametrize( + ( + "issuer_url", + "service_documentation_url", + "authorization_endpoint", + "token_endpoint", + "registration_endpoint", + "revocation_endpoint", + ), + ( + # Pydantic's AnyUrl incorrectly adds trailing slash to base URLs + # This is being fixed in https://github.com/pydantic/pydantic-core/pull/1719 (Pydantic 2.12+) + pytest.param( + "https://auth.example.com/", + "https://auth.example.com/docs", + "https://auth.example.com/authorize", + "https://auth.example.com/token", + "https://auth.example.com/register", + "https://auth.example.com/revoke", + id="simple-url", + marks=pytest.mark.xfail( + reason="Pydantic AnyUrl adds trailing slash to base URLs - fixed in Pydantic 2.12+" + ), + ), + pytest.param( + "https://auth.example.com/", + "https://auth.example.com/docs", + "https://auth.example.com/authorize", + "https://auth.example.com/token", + "https://auth.example.com/register", + "https://auth.example.com/revoke", + id="with-trailing-slash", + ), + pytest.param( + "https://auth.example.com/v1/mcp", + "https://auth.example.com/v1/mcp/docs", + "https://auth.example.com/v1/mcp/authorize", + "https://auth.example.com/v1/mcp/token", + "https://auth.example.com/v1/mcp/register", + "https://auth.example.com/v1/mcp/revoke", + id="with-path-param", + ), + ), +) +def test_build_metadata( + issuer_url: str, + service_documentation_url: str, + authorization_endpoint: str, + token_endpoint: str, + registration_endpoint: str, + revocation_endpoint: str, +): + from mcp.server.auth.routes import build_metadata + from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions + + metadata = build_metadata( + issuer_url=AnyHttpUrl(issuer_url), + service_documentation_url=AnyHttpUrl(service_documentation_url), + client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]), + revocation_options=RevocationOptions(enabled=True), + ) + + assert metadata.model_dump(exclude_defaults=True, mode="json") == snapshot( + { + "issuer": Is(issuer_url), + "authorization_endpoint": Is(authorization_endpoint), + "token_endpoint": Is(token_endpoint), + "registration_endpoint": Is(registration_endpoint), + "scopes_supported": ["read", "write", "admin"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + "service_documentation": Is(service_documentation_url), + "revocation_endpoint": Is(revocation_endpoint), + "revocation_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + "code_challenge_methods_supported": ["S256"], + } + ) + + +class TestLegacyServerFallback: + """Test backward compatibility with legacy servers that don't support PRM (issue #1495).""" + + @pytest.mark.anyio + async def test_legacy_server_no_prm_falls_back_to_root_oauth_discovery( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that when PRM discovery fails completely, we fall back to root OAuth discovery (March 2025 spec).""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + # Simulate a legacy server like Linear + provider = OAuthClientProvider( + server_url="https://mcp.linear.app/sse", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Mock client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://mcp.linear.app/sse") + auth_flow = provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send 401 without WWW-Authenticate header (typical legacy server) + response = httpx.Response(401, headers={}, request=test_request) + + # Should try path-based PRM first + prm_request_1 = await auth_flow.asend(response) + assert str(prm_request_1.url) == "https://mcp.linear.app/.well-known/oauth-protected-resource/sse" + + # PRM returns 404 + prm_response_1 = httpx.Response(404, request=prm_request_1) + + # Should try root-based PRM + prm_request_2 = await auth_flow.asend(prm_response_1) + assert str(prm_request_2.url) == "https://mcp.linear.app/.well-known/oauth-protected-resource" + + # PRM returns 404 again - all PRM URLs failed + prm_response_2 = httpx.Response(404, request=prm_request_2) + + # Should fall back to root OAuth discovery (March 2025 spec behavior) + oauth_metadata_request = await auth_flow.asend(prm_response_2) + assert str(oauth_metadata_request.url) == "https://mcp.linear.app/.well-known/oauth-authorization-server" + assert oauth_metadata_request.method == "GET" + + # Send successful OAuth metadata response + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://mcp.linear.app/", ' + b'"authorization_endpoint": "https://mcp.linear.app/authorize", ' + b'"token_endpoint": "https://mcp.linear.app/token"}' + ), + request=oauth_metadata_request, + ) + + # Mock authorization + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + # Next should be token exchange + token_request = await auth_flow.asend(oauth_metadata_response) + assert str(token_request.url) == "https://mcp.linear.app/token" + + # Send successful token response + token_response = httpx.Response( + 200, + content=b'{"access_token": "linear_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + # Final request with auth header + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer linear_token" + assert str(final_request.url) == "https://mcp.linear.app/sse" + + # Complete flow + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_legacy_server_with_different_prm_and_root_urls( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test PRM fallback with different WWW-Authenticate and root URLs.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + await auth_flow.__anext__() + + # 401 with custom WWW-Authenticate PRM URL + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://custom.prm.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + + # Try custom PRM URL first + prm_request_1 = await auth_flow.asend(response) + assert str(prm_request_1.url) == "https://custom.prm.com/.well-known/oauth-protected-resource" + + # Returns 500 + prm_response_1 = httpx.Response(500, request=prm_request_1) + + # Try path-based fallback + prm_request_2 = await auth_flow.asend(prm_response_1) + assert str(prm_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + + # Returns 404 + prm_response_2 = httpx.Response(404, request=prm_request_2) + + # Try root fallback + prm_request_3 = await auth_flow.asend(prm_response_2) + assert str(prm_request_3.url) == "https://api.example.com/.well-known/oauth-protected-resource" + + # Also returns 404 - all PRM URLs failed + prm_response_3 = httpx.Response(404, request=prm_request_3) + + # Should fall back to root OAuth discovery + oauth_metadata_request = await auth_flow.asend(prm_response_3) + assert str(oauth_metadata_request.url) == "https://api.example.com/.well-known/oauth-authorization-server" + + # Complete the flow + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://api.example.com/", ' + b'"authorization_endpoint": "https://api.example.com/authorize", ' + b'"token_endpoint": "https://api.example.com/token"}' + ), + request=oauth_metadata_request, + ) + + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + token_request = await auth_flow.asend(oauth_metadata_response) + assert str(token_request.url) == "https://api.example.com/token" + + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer test_token" + + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + +class TestSEP985Discovery: + """Test SEP-985 protected resource metadata discovery with fallback.""" + + @pytest.mark.anyio + async def test_path_based_fallback_when_no_www_authenticate( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that client falls back to path-based well-known URI when WWW-Authenticate is absent.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Test with 401 response without WWW-Authenticate header + init_response = httpx.Response( + status_code=401, headers={}, request=httpx.Request("GET", "https://api.example.com/v1/mcp") + ) + + # Build discovery URLs + discovery_urls = build_protected_resource_metadata_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) + + # Should have path-based URL first, then root-based URL + assert len(discovery_urls) == 2 + assert discovery_urls[0] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" + + @pytest.mark.anyio + async def test_root_based_fallback_after_path_based_404( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that client falls back to root-based URI when path-based returns 404.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Ensure no tokens are stored + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Mock client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + # Create a test request + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + + # Mock the auth flow + auth_flow = provider.async_auth_flow(test_request) + + # First request should be the original request without auth header + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send a 401 response without WWW-Authenticate header + response = httpx.Response(401, headers={}, request=test_request) + + # Next request should be to discover protected resource metadata (path-based) + discovery_request_1 = await auth_flow.asend(response) + assert str(discovery_request_1.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert discovery_request_1.method == "GET" + + # Send 404 response for path-based discovery + discovery_response_1 = httpx.Response(404, request=discovery_request_1) + + # Next request should be to root-based well-known URI + discovery_request_2 = await auth_flow.asend(discovery_response_1) + assert str(discovery_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource" + assert discovery_request_2.method == "GET" + + # Send successful discovery response + discovery_response_2 = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/"]}' ), - # Complex header with multiple parameters - ( - 'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", ' - 'error="insufficient_scope"', - "https://api.example.com/.well-known/oauth-protected-resource", + request=discovery_request_2, + ) + + # Mock the rest of the OAuth flow + provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + + # Next should be OAuth metadata discovery + oauth_metadata_request = await auth_flow.asend(discovery_response_2) + assert oauth_metadata_request.method == "GET" + + # Complete the flow + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com/", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token"}' ), - # Different URL format - ('Bearer resource_metadata="https://custom.domain.com/metadata"', "https://custom.domain.com/metadata"), - # With path and query params - ( - 'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"', - "https://api.example.com/auth/metadata?version=1", + request=oauth_metadata_request, + ) + + token_request = await auth_flow.asend(oauth_metadata_response) + token_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, ' + b'"refresh_token": "new_refresh_token"}' ), - ], - ) - def test_extract_resource_metadata_from_www_auth_valid_cases( - self, - client_metadata: OAuthClientMetadata, - mock_storage: MockTokenStorage, - www_auth_header: str, - expected_url: str, + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: # pragma: no cover + pass + + @pytest.mark.anyio + async def test_www_authenticate_takes_priority_over_well_known( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test extraction of resource_metadata URL from various valid WWW-Authenticate headers.""" + """Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs.""" async def redirect_handler(url: str) -> None: - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -895,54 +1726,346 @@ async def callback_handler() -> tuple[str, str | None]: callback_handler=callback_handler, ) + # Test with 401 response with WWW-Authenticate header init_response = httpx.Response( status_code=401, - headers={"WWW-Authenticate": www_auth_header}, - request=httpx.Request("GET", "https://api.example.com/test"), + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"' + }, + request=httpx.Request("GET", "https://api.example.com/v1/mcp"), + ) + + # Build discovery URLs + discovery_urls = build_protected_resource_metadata_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url ) - result = provider._extract_resource_metadata_from_www_auth(init_response) - assert result == expected_url + # Should have WWW-Authenticate URL first, then fallback URLs + assert len(discovery_urls) == 3 + assert discovery_urls[0] == "https://custom.example.com/.well-known/oauth-protected-resource" + assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert discovery_urls[2] == "https://api.example.com/.well-known/oauth-protected-resource" + + +class TestWWWAuthenticate: + """Test WWW-Authenticate header parsing functionality.""" @pytest.mark.parametrize( - "status_code,www_auth_header,description", + "www_auth_header,field_name,expected_value", [ - # No header - (401, None, "no WWW-Authenticate header"), - # Empty header - (401, "", "empty WWW-Authenticate header"), - # Header without resource_metadata - (401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"), - # Malformed header - (401, "Bearer resource_metadata=", "malformed resource_metadata parameter"), - # Non-401 status code + # Quoted values + ('Bearer scope="read write"', "scope", "read write"), ( - 200, 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', - "200 OK response", + "resource_metadata", + "https://api.example.com/.well-known/oauth-protected-resource", ), + ('Bearer error="insufficient_scope"', "error", "insufficient_scope"), + # Unquoted values + ("Bearer scope=read", "scope", "read"), ( - 500, - 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', - "500 error response", + "Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource", + "resource_metadata", + "https://api.example.com/.well-known/oauth-protected-resource", + ), + ("Bearer error=invalid_token", "error", "invalid_token"), + # Multiple parameters with quoted value + ( + 'Bearer realm="api", scope="admin:write resource:read", error="insufficient_scope"', + "scope", + "admin:write resource:read", + ), + ( + 'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", ' + 'error="insufficient_scope"', + "resource_metadata", + "https://api.example.com/.well-known/oauth-protected-resource", + ), + # Multiple parameters with unquoted value + ('Bearer realm="api", scope=basic', "scope", "basic"), + # Values with special characters + ( + 'Bearer scope="resource:read resource:write user_profile"', + "scope", + "resource:read resource:write user_profile", + ), + ( + 'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"', + "resource_metadata", + "https://api.example.com/auth/metadata?version=1", ), ], ) - def test_extract_resource_metadata_from_www_auth_invalid_cases( + def test_extract_field_from_www_auth_valid_cases( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + www_auth_header: str, + field_name: str, + expected_value: str, + ): + """Test extraction of various fields from valid WWW-Authenticate headers.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": www_auth_header}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, field_name) + assert result == expected_value + + @pytest.mark.parametrize( + "www_auth_header,field_name,description", + [ + # No header + (None, "scope", "no WWW-Authenticate header"), + # Empty header + ("", "scope", "empty WWW-Authenticate header"), + # Header without requested field + ('Bearer realm="api", error="insufficient_scope"', "scope", "no scope parameter"), + ('Bearer realm="api", scope="read write"', "resource_metadata", "no resource_metadata parameter"), + # Malformed field (empty value) + ("Bearer scope=", "scope", "malformed scope parameter"), + ("Bearer resource_metadata=", "resource_metadata", "malformed resource_metadata parameter"), + ], + ) + def test_extract_field_from_www_auth_invalid_cases( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage, - status_code: int, www_auth_header: str | None, + field_name: str, description: str, ): """Test extraction returns None for invalid cases.""" + headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} + init_response = httpx.Response( + status_code=401, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") + ) + + result = extract_field_from_www_auth(init_response, field_name) + assert result is None, f"Should return None for {description}" + + +class TestCIMD: + """Test Client ID Metadata Document (CIMD) support.""" + + @pytest.mark.parametrize( + "url,expected", + [ + # Valid CIMD URLs + ("https://example.com/client", True), + ("https://example.com/client-metadata.json", True), + ("https://example.com/path/to/client", True), + ("https://example.com:8443/client", True), + # Invalid URLs - HTTP (not HTTPS) + ("http://example.com/client", False), + # Invalid URLs - root path + ("https://example.com/", False), + ("https://example.com/", False), + # Invalid URLs - None or empty + (None, False), + ("", False), + # Invalid URLs - malformed (triggers urlparse exception) + ("http://[::1/foo/", False), + ], + ) + def test_is_valid_client_metadata_url(http://github.com/self,%20url:%20str%20|%20None,%20expected:%20bool): + """Test CIMD URL validation.""" + assert is_valid_client_metadata_http://github.com/url(url) == expected + + def test_should_use_client_metadata_url_when_server_supports(self): + """Test that CIMD is used when server supports it and URL is provided.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=True, + ) + assert should_use_client_metadata_url(http://github.com/oauth_metadata,%20%22https://example.com/client") is True + + def test_should_not_use_client_metadata_url_when_server_does_not_support(self): + """Test that CIMD is not used when server doesn't support it.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=False, + ) + assert should_use_client_metadata_url(http://github.com/oauth_metadata,%20%22https://example.com/client") is False + + def test_should_not_use_client_metadata_url_when_not_provided(self): + """Test that CIMD is not used when no URL is provided.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com/"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=True, + ) + assert should_use_client_metadata_url(http://github.com/oauth_metadata,%20None) is False + + def test_should_not_use_client_metadata_url_when_no_metadata(self): + """Test that CIMD is not used when OAuth metadata is None.""" + assert should_use_client_metadata_url(http://github.com/None,%20%22https://example.com/client") is False + + def test_create_client_info_from_metadata_url(http://github.com/self): + """Test creating client info from CIMD URL.""" + client_info = create_client_info_from_metadata_url( + "https://example.com/client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + assert client_info.client_id == "https://example.com/client" + assert client_info.token_endpoint_auth_method == "none" + assert client_info.redirect_uris == [AnyUrl("http://localhost:3030/callback")] + assert client_info.client_secret is None + + def test_oauth_provider_with_valid_client_metadata_url( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test OAuthClientProvider initialization with valid client_metadata_url.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + assert provider.context.client_metadata_url == "https://example.com/client" + + def test_oauth_provider_with_invalid_client_metadata_url_raises_error( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test OAuthClientProvider raises error for invalid client_metadata_url.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + with pytest.raises(ValueError) as exc_info: + OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="http://example.com/client", # HTTP instead of HTTPS + ) + assert "HTTPS URL with a non-root pathname" in str(exc_info.value) + + @pytest.mark.anyio + async def test_auth_flow_uses_cimd_when_server_supports( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that auth flow uses CIMD URL as client_id when server supports it.""" + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send 401 response + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/"]}', + request=prm_request, + ) + + # OAuth metadata discovery + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com/", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"client_id_metadata_document_supported": true}' + ), + request=oauth_request, + ) + + # Mock authorization + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + # Should skip DCR and go directly to token exchange + token_request = await auth_flow.asend(oauth_response) + assert token_request.method == "POST" + assert str(token_request.url) == "https://auth.example.com/token" + + # Verify client_id is the CIMD URL + content = token_request.content.decode() + assert "client_id=https%3A%2F%2Fexample.com%2Fclient" in content + + # Verify client info was set correctly + assert provider.context.client_info is not None + assert provider.context.client_info.client_id == "https://example.com/client" + assert provider.context.client_info.token_endpoint_auth_method == "none" + + # Complete the flow + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer test_token" + + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: pass + @pytest.mark.anyio + async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that auth flow falls back to DCR when server doesn't support CIMD.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -950,12 +2073,70 @@ async def callback_handler() -> tuple[str, str | None]: storage=mock_storage, redirect_handler=redirect_handler, callback_handler=callback_handler, + client_metadata_url="https://example.com/client", ) - headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} - init_response = httpx.Response( - status_code=status_code, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + await auth_flow.__anext__() + + # Send 401 response + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/"]}', + request=prm_request, ) - result = provider._extract_resource_metadata_from_www_auth(init_response) - assert result is None, f"Should return None for {description}" + # OAuth metadata discovery - server does NOT support CIMD + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com/", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=oauth_request, + ) + + # Should proceed to DCR instead of skipping it + registration_request = await auth_flow.asend(oauth_response) + assert registration_request.method == "POST" + assert str(registration_request.url) == "https://auth.example.com/register" + + # Complete the flow to avoid generator cleanup issues + registration_response = httpx.Response( + 201, + content=b'{"client_id": "dcr_client_id", "redirect_uris": ["http://localhost:3030/callback"]}', + request=registration_request, + ) + + # Mock authorization + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + token_request = await auth_flow.asend(registration_response) + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass diff --git a/tests/client/test_config.py b/tests/client/test_config.py index f144dcffb9..d1a0576ff3 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -44,7 +44,7 @@ def test_command_execution(mock_config_path: Path): test_args = [command] + args + ["--help"] - result = subprocess.run(test_args, capture_output=True, text=True, timeout=5, check=False) + result = subprocess.run(test_args, capture_output=True, text=True, timeout=20, check=False) assert result.returncode == 0 assert "usage" in result.stdout.lower() diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index edf8675e56..ec38f35838 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -7,13 +7,13 @@ import multiprocessing import socket -import time from collections.abc import Generator import pytest from mcp.client.session import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client +from tests.test_helpers import wait_for_server # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -35,7 +35,7 @@ } -def run_unicode_server(port: int) -> None: +def run_unicode_server(port: int) -> None: # pragma: no cover """Run the Unicode test server in a separate process.""" # Import inside the function since this runs in a separate process from collections.abc import AsyncGenerator @@ -158,19 +158,8 @@ def running_unicode_server(unicode_server_port: int) -> Generator[str, None, Non proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) proc.start() - # Wait for server to be running - max_attempts = 20 - attempt = 0 - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", unicode_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + # Wait for server to be ready + wait_for_server(unicode_server_port) try: yield f"http://127.0.0.1:{unicode_server_port}" @@ -178,7 +167,7 @@ def running_unicode_server(unicode_server_port: int) -> Generator[str, None, Non # Clean up - try graceful termination first proc.terminate() proc.join(timeout=2) - if proc.is_alive(): + if proc.is_alive(): # pragma: no cover proc.kill() proc.join(timeout=1) @@ -189,7 +178,7 @@ async def test_streamable_http_client_unicode_tool_call(running_unicode_server: base_url = running_unicode_server endpoint_url = f"{base_url}/mcp" - async with streamablehttp_client(endpoint_url) as (read_stream, write_stream, _get_session_id): + async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() @@ -221,7 +210,7 @@ async def test_streamable_http_client_unicode_prompts(running_unicode_server: st base_url = running_unicode_server endpoint_url = f"{base_url}/mcp" - async with streamablehttp_client(endpoint_url) as (read_stream, write_stream, _get_session_id): + async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index b31b704a40..94a72c34e2 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -2,213 +2,227 @@ import pytest +import mcp.types as types +from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.memory import create_connected_server_and_client_session as create_session +from mcp.types import ListToolsRequest, ListToolsResult from .conftest import StreamSpyCollection pytestmark = pytest.mark.anyio -async def test_list_tools_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): - """Test that the cursor parameter is accepted for list_tools - and that it is correctly passed to the server. - - See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format - """ +@pytest.fixture +async def full_featured_server(): + """Create a server with tools, resources, prompts, and templates.""" server = FastMCP("test") - # Create a couple of test tools @server.tool(name="test_tool_1") - async def test_tool_1() -> str: + async def test_tool_1() -> str: # pragma: no cover """First test tool""" return "Result 1" @server.tool(name="test_tool_2") - async def test_tool_2() -> str: + async def test_tool_2() -> str: # pragma: no cover """Second test tool""" return "Result 2" - async with create_session(server._mcp_server) as client_session: - spies = stream_spy() - - # Test without cursor parameter (omitted) - _ = await client_session.list_tools() - list_tools_requests = spies.get_client_requests(method="tools/list") - assert len(list_tools_requests) == 1 - assert list_tools_requests[0].params is None - - spies.clear() - - # Test with cursor=None - _ = await client_session.list_tools(cursor=None) - list_tools_requests = spies.get_client_requests(method="tools/list") - assert len(list_tools_requests) == 1 - assert list_tools_requests[0].params is None - - spies.clear() - - # Test with cursor as string - _ = await client_session.list_tools(cursor="some_cursor_value") - list_tools_requests = spies.get_client_requests(method="tools/list") - assert len(list_tools_requests) == 1 - assert list_tools_requests[0].params is not None - assert list_tools_requests[0].params["cursor"] == "some_cursor_value" - - spies.clear() + @server.resource("resource://test/data") + async def test_resource() -> str: # pragma: no cover + """Test resource""" + return "Test data" - # Test with empty string cursor - _ = await client_session.list_tools(cursor="") - list_tools_requests = spies.get_client_requests(method="tools/list") - assert len(list_tools_requests) == 1 - assert list_tools_requests[0].params is not None - assert list_tools_requests[0].params["cursor"] == "" + @server.prompt() + async def test_prompt(name: str) -> str: # pragma: no cover + """Test prompt""" + return f"Hello, {name}!" + @server.resource("resource://test/{name}") + async def test_template(name: str) -> str: # pragma: no cover + """Test resource template""" + return f"Data for {name}" -async def test_list_resources_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): - """Test that the cursor parameter is accepted for list_resources - and that it is correctly passed to the server. + return server + + +@pytest.mark.parametrize( + "method_name,request_method", + [ + ("list_tools", "tools/list"), + ("list_resources", "resources/list"), + ("list_prompts", "prompts/list"), + ("list_resource_templates", "resources/templates/list"), + ], +) +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +async def test_list_methods_cursor_parameter( + stream_spy: Callable[[], StreamSpyCollection], + full_featured_server: FastMCP, + method_name: str, + request_method: str, +): + """Test that the cursor parameter is accepted and correctly passed to the server. + + Covers: list_tools, list_resources, list_prompts, list_resource_templates See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format """ - server = FastMCP("test") - - # Create a test resource - @server.resource("resource://test/data") - async def test_resource() -> str: - """Test resource""" - return "Test data" - - async with create_session(server._mcp_server) as client_session: + async with create_session(full_featured_server._mcp_server) as client_session: spies = stream_spy() # Test without cursor parameter (omitted) - _ = await client_session.list_resources() - list_resources_requests = spies.get_client_requests(method="resources/list") - assert len(list_resources_requests) == 1 - assert list_resources_requests[0].params is None + method = getattr(client_session, method_name) + _ = await method() + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is None spies.clear() # Test with cursor=None - _ = await client_session.list_resources(cursor=None) - list_resources_requests = spies.get_client_requests(method="resources/list") - assert len(list_resources_requests) == 1 - assert list_resources_requests[0].params is None + _ = await method(cursor=None) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is None spies.clear() # Test with cursor as string - _ = await client_session.list_resources(cursor="some_cursor") - list_resources_requests = spies.get_client_requests(method="resources/list") - assert len(list_resources_requests) == 1 - assert list_resources_requests[0].params is not None - assert list_resources_requests[0].params["cursor"] == "some_cursor" + _ = await method(cursor="some_cursor_value") + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is not None + assert requests[0].params["cursor"] == "some_cursor_value" spies.clear() # Test with empty string cursor - _ = await client_session.list_resources(cursor="") - list_resources_requests = spies.get_client_requests(method="resources/list") - assert len(list_resources_requests) == 1 - assert list_resources_requests[0].params is not None - assert list_resources_requests[0].params["cursor"] == "" - - -async def test_list_prompts_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): - """Test that the cursor parameter is accepted for list_prompts - and that it is correctly passed to the server. - See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format + _ = await method(cursor="") + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is not None + assert requests[0].params["cursor"] == "" + + +@pytest.mark.parametrize( + "method_name,request_method", + [ + ("list_tools", "tools/list"), + ("list_resources", "resources/list"), + ("list_prompts", "prompts/list"), + ("list_resource_templates", "resources/templates/list"), + ], +) +async def test_list_methods_params_parameter( + stream_spy: Callable[[], StreamSpyCollection], + full_featured_server: FastMCP, + method_name: str, + request_method: str, +): + """Test that the params parameter works correctly for list methods. + + Covers: list_tools, list_resources, list_prompts, list_resource_templates + + This tests the new params parameter API (non-deprecated) to ensure + it correctly handles all parameter combinations. """ - server = FastMCP("test") - - # Create a test prompt - @server.prompt() - async def test_prompt(name: str) -> str: - """Test prompt""" - return f"Hello, {name}!" - - async with create_session(server._mcp_server) as client_session: + async with create_session(full_featured_server._mcp_server) as client_session: spies = stream_spy() + method = getattr(client_session, method_name) - # Test without cursor parameter (omitted) - _ = await client_session.list_prompts() - list_prompts_requests = spies.get_client_requests(method="prompts/list") - assert len(list_prompts_requests) == 1 - assert list_prompts_requests[0].params is None + # Test without params parameter (omitted) + _ = await method() + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is None spies.clear() - # Test with cursor=None - _ = await client_session.list_prompts(cursor=None) - list_prompts_requests = spies.get_client_requests(method="prompts/list") - assert len(list_prompts_requests) == 1 - assert list_prompts_requests[0].params is None + # Test with params=None + _ = await method(params=None) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is None spies.clear() - # Test with cursor as string - _ = await client_session.list_prompts(cursor="some_cursor") - list_prompts_requests = spies.get_client_requests(method="prompts/list") - assert len(list_prompts_requests) == 1 - assert list_prompts_requests[0].params is not None - assert list_prompts_requests[0].params["cursor"] == "some_cursor" + # Test with empty params (for strict servers) + _ = await method(params=types.PaginatedRequestParams()) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is not None + assert requests[0].params.get("cursor") is None spies.clear() - # Test with empty string cursor - _ = await client_session.list_prompts(cursor="") - list_prompts_requests = spies.get_client_requests(method="prompts/list") - assert len(list_prompts_requests) == 1 - assert list_prompts_requests[0].params is not None - assert list_prompts_requests[0].params["cursor"] == "" - - -async def test_list_resource_templates_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): - """Test that the cursor parameter is accepted for list_resource_templates - and that it is correctly passed to the server. - - See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format + # Test with params containing cursor + _ = await method(params=types.PaginatedRequestParams(cursor="some_cursor_value")) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is not None + assert requests[0].params["cursor"] == "some_cursor_value" + + +@pytest.mark.parametrize( + "method_name", + [ + "list_tools", + "list_resources", + "list_prompts", + "list_resource_templates", + ], +) +async def test_list_methods_raises_error_when_both_cursor_and_params_provided( + full_featured_server: FastMCP, + method_name: str, +): + """Test that providing both cursor and params raises ValueError. + + Covers: list_tools, list_resources, list_prompts, list_resource_templates + + When both cursor and params are provided, a ValueError should be raised + to prevent ambiguity. """ - server = FastMCP("test") + async with create_session(full_featured_server._mcp_server) as client_session: + method = getattr(client_session, method_name) - # Create a test resource template - @server.resource("resource://test/{name}") - async def test_template(name: str) -> str: - """Test resource template""" - return f"Data for {name}" + # Call with both cursor and params - should raise ValueError + with pytest.raises(ValueError, match="Cannot specify both cursor and params"): + await method( + cursor="old_cursor", + params=types.PaginatedRequestParams(cursor="new_cursor"), + ) - async with create_session(server._mcp_server) as client_session: - spies = stream_spy() - # Test without cursor parameter (omitted) - _ = await client_session.list_resource_templates() - list_templates_requests = spies.get_client_requests(method="resources/templates/list") - assert len(list_templates_requests) == 1 - assert list_templates_requests[0].params is None +async def test_list_tools_with_strict_server_validation(): + """Test that list_tools works with strict servers require a params field, + even if it is empty. - spies.clear() + Some MCP servers may implement strict JSON-RPC validation that requires + the params field to always be present in requests, even if empty {}. - # Test with cursor=None - _ = await client_session.list_resource_templates(cursor=None) - list_templates_requests = spies.get_client_requests(method="resources/templates/list") - assert len(list_templates_requests) == 1 - assert list_templates_requests[0].params is None + This test ensures such servers are supported by the client SDK for list_resources + requests without a cursor. + """ - spies.clear() + server = Server("strict_server") - # Test with cursor as string - _ = await client_session.list_resource_templates(cursor="some_cursor") - list_templates_requests = spies.get_client_requests(method="resources/templates/list") - assert len(list_templates_requests) == 1 - assert list_templates_requests[0].params is not None - assert list_templates_requests[0].params["cursor"] == "some_cursor" + @server.list_tools() + async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: # pragma: no cover + """Strict handler that validates params field exists""" - spies.clear() + # Simulate strict server validation + if request.params is None: + raise ValueError( + "Strict server validation failed: params field must be present. " + "Expected params: {} for requests without cursor." + ) - # Test with empty string cursor - _ = await client_session.list_resource_templates(cursor="") - list_templates_requests = spies.get_client_requests(method="resources/templates/list") - assert len(list_templates_requests) == 1 - assert list_templates_requests[0].params is not None - assert list_templates_requests[0].params["cursor"] == "" + # Return empty tools list + return ListToolsResult(tools=[]) + + async with create_session(server) as client_session: + # Use params to explicitly send params: {} for strict server compatibility + result = await client_session.list_tools(params=types.PaginatedRequestParams()) + assert result is not None diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index f298ee2871..de058eb061 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Any, Literal import pytest @@ -47,11 +47,28 @@ async def test_tool_with_log( ) return True + @server.tool("test_tool_with_log_extra") + async def test_tool_with_log_extra( + message: str, + level: Literal["debug", "info", "warning", "error"], + logger: str, + extra_string: str, + extra_dict: dict[str, Any], + ) -> bool: + """Send a log notification to the client with extra fields.""" + await server.get_context().log( + level=level, + message=message, + logger_name=logger, + extra={"extra_string": extra_string, "extra_dict": extra_dict}, + ) + return True + # Create a message handler to catch exceptions async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message async with create_session( @@ -74,10 +91,30 @@ async def message_handler( "logger": "test_logger", }, ) + log_result_with_extra = await client_session.call_tool( + "test_tool_with_log_extra", + { + "message": "Test log message", + "level": "info", + "logger": "test_logger", + "extra_string": "example", + "extra_dict": {"a": 1, "b": 2, "c": 3}, + }, + ) assert log_result.isError is False - assert len(logging_collector.log_messages) == 1 + assert log_result_with_extra.isError is False + assert len(logging_collector.log_messages) == 2 # Create meta object with related_request_id added dynamically log = logging_collector.log_messages[0] assert log.level == "info" assert log.logger == "test_logger" assert log.data == "Test log message" + + log_with_extra = logging_collector.log_messages[1] + assert log_with_extra.level == "info" + assert log_with_extra.logger == "test_logger" + assert log_with_extra.data == { + "message": "Test log message", + "extra_string": "example", + "extra_dict": {"a": 1, "b": 2, "c": 3}, + } diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 88e64711b5..7500abee73 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -8,7 +8,6 @@ import json import multiprocessing import socket -import time from collections.abc import Generator import pytest @@ -19,12 +18,13 @@ from starlette.routing import Route from mcp import ClientSession, types -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.shared.session import RequestResponder from mcp.types import ClientNotification, RootsListChangedNotification +from tests.test_helpers import wait_for_server -def create_non_sdk_server_app() -> Starlette: +def create_non_sdk_server_app() -> Starlette: # pragma: no cover """Create a minimal server that doesn't follow SDK conventions.""" async def handle_mcp_request(request: Request) -> Response: @@ -67,7 +67,7 @@ async def handle_mcp_request(request: Request) -> Response: return app -def run_non_sdk_server(port: int) -> None: +def run_non_sdk_server(port: int) -> None: # pragma: no cover """Run the non-SDK server in a separate process.""" app = create_non_sdk_server_app() config = uvicorn.Config( @@ -95,14 +95,9 @@ def non_sdk_server(non_sdk_server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be ready - start_time = time.time() - while time.time() - start_time < 10: - try: - with socket.create_connection(("127.0.0.1", non_sdk_server_port), timeout=0.1): - break - except (TimeoutError, ConnectionRefusedError): - time.sleep(0.1) - else: + try: # pragma: no cover + wait_for_server(non_sdk_server_port, timeout=10.0) + except TimeoutError: # pragma: no cover proc.kill() proc.join(timeout=2) pytest.fail("Server failed to start within 10 seconds") @@ -125,14 +120,14 @@ async def test_non_compliant_notification_response(non_sdk_server: None, non_sdk server_url = f"http://127.0.0.1:{non_sdk_server_port}/mcp" returned_exception = None - async def message_handler( + async def message_handler( # pragma: no cover message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ): nonlocal returned_exception if isinstance(message, Exception): returned_exception = message - async with streamablehttp_client(server_url) as (read_stream, write_stream, _): + async with streamable_http_client(server_url) as (read_stream, write_stream, _): async with ClientSession( read_stream, write_stream, @@ -146,5 +141,5 @@ async def message_handler( ClientNotification(RootsListChangedNotification(method="notifications/roots/list_changed")) ) - if returned_exception: + if returned_exception: # pragma: no cover pytest.fail(f"Server encountered an exception: {returned_exception}") diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index 4e649b0eb2..e4a06b7f82 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -19,9 +19,27 @@ def bypass_server_output_validation(): This simulates a malicious or non-compliant server that doesn't validate its outputs, allowing us to test client-side validation. """ - # Patch jsonschema.validate in the server module to disable all validation - with patch("mcp.server.lowlevel.server.jsonschema.validate"): - # The mock will simply return None (do nothing) for all validation calls + import jsonschema + + # Save the original validate function + original_validate = jsonschema.validate + + # Create a mock that tracks which module is calling it + def selective_mock(instance: Any = None, schema: Any = None, *args: Any, **kwargs: Any) -> None: + import inspect + + # Check the call stack to see where this is being called from + for frame_info in inspect.stack(): + # If called from the server module, skip validation + # TODO: fix this as it's a rather gross workaround and will eventually break + # Normalize path separators for cross-platform compatibility + normalized_path = frame_info.filename.replace("\\", "/") + if "mcp/server/lowlevel/server.py" in normalized_path: + return None + # Otherwise, use the real validation (for client-side) + return original_validate(instance=instance, schema=schema, *args, **kwargs) + + with patch("jsonschema.validate", selective_mock): yield diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index e0b4815817..cc6c5059fd 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -19,7 +19,9 @@ async def test_send_request_stream_cleanup(): # Create a mock session with the minimal required functionality class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]): - async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: + async def _send_response( + self, request_id: RequestId, response: SendResultT | ErrorData + ) -> None: # pragma: no cover pass # Create streams diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index a3f6affda8..733364a767 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -8,8 +8,10 @@ from mcp.types import ( CreateMessageRequestParams, CreateMessageResult, + CreateMessageResultWithTools, SamplingMessage, TextContent, + ToolUseContent, ) @@ -56,3 +58,79 @@ async def test_sampling_tool(message: str): assert result.isError is True assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported" + + +@pytest.mark.anyio +async def test_create_message_backwards_compat_single_content(): + """Test backwards compatibility: create_message without tools returns single content.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + # Callback returns single content (text) + callback_return = CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="Hello from LLM"), + model="test-model", + stopReason="endTurn", + ) + + async def sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + return callback_return + + @server.tool("test_backwards_compat") + async def test_tool(message: str): + # Call create_message WITHOUT tools + result = await server.get_context().session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], + max_tokens=100, + ) + # Backwards compat: result should be CreateMessageResult + assert isinstance(result, CreateMessageResult) + # Content should be single (not a list) - this is the key backwards compat check + assert isinstance(result.content, TextContent) + assert result.content.text == "Hello from LLM" + # CreateMessageResult should NOT have content_as_list (that's on WithTools) + assert not hasattr(result, "content_as_list") or not callable(getattr(result, "content_as_list", None)) + return True + + async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session: + result = await client_session.call_tool("test_backwards_compat", {"message": "Test"}) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + +@pytest.mark.anyio +async def test_create_message_result_with_tools_type(): + """Test that CreateMessageResultWithTools supports content_as_list.""" + # Test the type itself, not the overload (overload requires client capability setup) + result = CreateMessageResultWithTools( + role="assistant", + content=ToolUseContent(type="tool_use", id="call_123", name="get_weather", input={"city": "SF"}), + model="test-model", + stopReason="toolUse", + ) + + # CreateMessageResultWithTools should have content_as_list + content_list = result.content_as_list + assert len(content_list) == 1 + assert content_list[0].type == "tool_use" + + # It should also work with array content + result_array = CreateMessageResultWithTools( + role="assistant", + content=[ + TextContent(type="text", text="Let me check the weather"), + ToolUseContent(type="tool_use", id="call_456", name="get_weather", input={"city": "NYC"}), + ], + model="test-model", + stopReason="toolUse", + ) + content_list_array = result_array.content_as_list + assert len(content_list_array) == 2 + assert content_list_array[0].type == "text" + assert content_list_array[1].type == "tool_use" diff --git a/tests/client/test_scope_bug_1630.py b/tests/client/test_scope_bug_1630.py new file mode 100644 index 0000000000..7884718c1e --- /dev/null +++ b/tests/client/test_scope_bug_1630.py @@ -0,0 +1,166 @@ +""" +Regression test for issue #1630: OAuth2 scope incorrectly set to resource_metadata URL. + +This test verifies that when a 401 response contains both resource_metadata and scope +in the WWW-Authenticate header, the actual scope is used (not the resource_metadata URL). +""" + +from unittest import mock + +import httpx +import pytest +from pydantic import AnyUrl + +from mcp.client.auth import OAuthClientProvider +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + + +class MockTokenStorage: + """Mock token storage for testing.""" + + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens # pragma: no cover + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info # pragma: no cover + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info # pragma: no cover + + +@pytest.mark.anyio +async def test_401_uses_www_auth_scope_not_resource_metadata_url(): + """ + Regression test for #1630: Ensure scope is extracted from WWW-Authenticate header, + not the resource_metadata URL. + + When a 401 response contains: + WWW-Authenticate: Bearer resource_metadata="https://.../", scope="read write" + + The client should use "read write" as the scope, NOT the resource_metadata URL. + """ + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + client_metadata = OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + client_name="Test Client", + ) + + provider = OAuthClientProvider( + server_url="https://api.example.com/mcp", + client_metadata=client_metadata, + storage=MockTokenStorage(), + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Pre-set client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + test_request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request (no auth header yet) + await auth_flow.__anext__() + + # 401 response with BOTH resource_metadata URL and scope in WWW-Authenticate + # This is the key: the bug would use the URL as scope instead of "read write" + resource_metadata_url = "https://api.example.com/.well-known/oauth-protected-resource" + expected_scope = "read write" + + response_401 = httpx.Response( + 401, + headers={"WWW-Authenticate": (f'Bearer resource_metadata="{resource_metadata_url}", scope="{expected_scope}"')}, + request=test_request, + ) + + # Send 401, expect PRM discovery request + prm_request = await auth_flow.asend(response_401) + assert ".well-known/oauth-protected-resource" in str(prm_request.url) + + # PRM response with scopes_supported (these should be overridden by WWW-Auth scope) + prm_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/mcp", ' + b'"authorization_servers": ["https://auth.example.com/"], ' + b'"scopes_supported": ["fallback:scope1", "fallback:scope2"]}' + ), + request=prm_request, + ) + + # Send PRM response, expect OAuth metadata discovery + oauth_metadata_request = await auth_flow.asend(prm_response) + assert ".well-known/oauth-authorization-server" in str(oauth_metadata_request.url) + + # OAuth metadata response + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com/", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token"}' + ), + request=oauth_metadata_request, + ) + + # Mock authorization to skip interactive flow + provider._perform_authorization_code_grant = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + + # Send OAuth metadata response, expect token request + token_request = await auth_flow.asend(oauth_metadata_response) + assert "token" in str(token_request.url) + + # NOW CHECK: The scope should be the WWW-Authenticate scope, NOT the URL + # This is where the bug manifested - scope was set to resource_metadata_url + actual_scope = provider.context.client_metadata.scope + + # This assertion would FAIL on main (scope would be the URL) + # but PASS on the fix branch (scope is "read write") + assert actual_scope == expected_scope, ( + f"Expected scope to be '{expected_scope}' from WWW-Authenticate header, " + f"but got '{actual_scope}'. " + f"If scope is '{resource_metadata_url}', the bug from #1630 is present." + ) + + # Verify it's definitely not the URL (explicit check for the bug) + assert actual_scope != resource_metadata_url, ( + f"BUG #1630: Scope was incorrectly set to resource_metadata URL '{resource_metadata_url}' " + f"instead of the actual scope '{expected_scope}'" + ) + + # Complete the flow to properly release the lock + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer test_token" + + # Finish the flow + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 53b60fce61..eb2683fbdb 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -11,6 +11,7 @@ from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, + CallToolResult, ClientNotification, ClientRequest, Implementation, @@ -23,6 +24,7 @@ JSONRPCResponse, ServerCapabilities, ServerResult, + TextContent, ) @@ -80,7 +82,7 @@ async def mock_server(): ) # Create a message handler to catch exceptions - async def message_handler( + async def message_handler( # pragma: no cover message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): @@ -424,7 +426,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None - async def custom_sampling_callback( + async def custom_sampling_callback( # pragma: no cover context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: @@ -434,7 +436,7 @@ async def custom_sampling_callback( model="test-model", ) - async def custom_list_roots_callback( + async def custom_list_roots_callback( # pragma: no cover context: RequestContext["ClientSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) @@ -492,8 +494,277 @@ async def mock_server(): # Assert that capabilities are properly set with custom callbacks assert received_capabilities is not None - assert received_capabilities.sampling is not None # Custom sampling callback provided + # Custom sampling callback provided + assert received_capabilities.sampling is not None assert isinstance(received_capabilities.sampling, types.SamplingCapability) - assert received_capabilities.roots is not None # Custom list_roots callback provided + # Default sampling capabilities (no tools) + assert received_capabilities.sampling.tools is None + # Custom list_roots callback provided + assert received_capabilities.roots is not None assert isinstance(received_capabilities.roots, types.RootsCapability) - assert received_capabilities.roots.listChanged is True # Should be True for custom callback + # Should be True for custom callback + assert received_capabilities.roots.listChanged is True + + +@pytest.mark.anyio +async def test_client_capabilities_with_sampling_tools(): + """Test that sampling capabilities with tools are properly advertised""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities = None + + async def custom_sampling_callback( # pragma: no cover + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent(type="text", text="test"), + model="test-model", + ) + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + sampling_callback=custom_sampling_callback, + sampling_capabilities=types.SamplingCapability(tools=types.SamplingToolsCapability()), + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that sampling capabilities with tools are properly advertised + assert received_capabilities is not None + assert received_capabilities.sampling is not None + assert isinstance(received_capabilities.sampling, types.SamplingCapability) + # Tools capability should be present + assert received_capabilities.sampling.tools is not None + assert isinstance(received_capabilities.sampling.tools, types.SamplingToolsCapability) + + +@pytest.mark.anyio +async def test_get_server_capabilities(): + """Test that get_server_capabilities returns None before init and capabilities after""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + expected_capabilities = ServerCapabilities( + logging=types.LoggingCapability(), + prompts=types.PromptsCapability(listChanged=True), + resources=types.ResourcesCapability(subscribe=True, listChanged=True), + tools=types.ToolsCapability(listChanged=False), + ) + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=expected_capabilities, + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + assert session.get_server_capabilities() is None + + tg.start_soon(mock_server) + await session.initialize() + + capabilities = session.get_server_capabilities() + assert capabilities is not None + assert capabilities == expected_capabilities + assert capabilities.logging is not None + assert capabilities.prompts is not None + assert capabilities.prompts.listChanged is True + assert capabilities.resources is not None + assert capabilities.resources.subscribe is True + assert capabilities.tools is not None + assert capabilities.tools.listChanged is False + + +@pytest.mark.anyio +@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) +async def test_client_tool_call_with_meta(meta: dict[str, Any] | None): + """Test that client tool call requests can include metadata""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + mocked_tool = types.Tool(name="sample_tool", inputSchema={}) + + async def mock_server(): + # Receive initialization request from client + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + # Answer initialization request + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Wait for the client to send a 'tools/call' request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + assert jsonrpc_request.root.method == "tools/call" + + if meta is not None: + assert jsonrpc_request.root.params + assert "_meta" in jsonrpc_request.root.params + assert jsonrpc_request.root.params["_meta"] == meta + + result = ServerResult( + CallToolResult(content=[TextContent(type="text", text="Called successfully")], isError=False) + ) + + # Send the tools/call result + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Wait for the tools/list request from the client + # The client requires this step to validate the tool output schema + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + assert jsonrpc_request.root.method == "tools/list" + + result = types.ListToolsResult(tools=[mocked_tool]) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + server_to_client_send.close() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + await session.initialize() + + await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index c38cfeabcc..b03fe9ca88 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -5,7 +5,12 @@ import mcp from mcp import types -from mcp.client.session_group import ClientSessionGroup, SseServerParameters, StreamableHttpParameters +from mcp.client.session_group import ( + ClientSessionGroup, + ClientSessionParameters, + SseServerParameters, + StreamableHttpParameters, +) from mcp.client.stdio import StdioServerParameters from mcp.shared.exceptions import McpError @@ -50,7 +55,7 @@ async def test_call_tool(self): mock_session = mock.AsyncMock() # --- Prepare Session Group --- - def hook(name: str, server_info: types.Implementation) -> str: + def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cover return f"{(server_info.name)}-{name}" mcp_session_group = ClientSessionGroup(component_name_hook=hook) @@ -62,7 +67,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # --- Test Execution --- result = await mcp_session_group.call_tool( name="server1-my_tool", - args={ + arguments={ "name": "value1", "args": {}, }, @@ -73,6 +78,9 @@ def hook(name: str, server_info: types.Implementation) -> str: mock_session.call_tool.assert_called_once_with( "my_tool", {"name": "value1", "args": {}}, + read_timeout_seconds=None, + progress_callback=None, + meta=None, ) async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack): @@ -265,14 +273,14 @@ async def test_disconnect_non_existent_server(self): "mcp.client.session_group.mcp.stdio_client", ), ( - SseServerParameters(url="http://test.com/sse", timeout=10), + SseServerParameters(url="http://test.com/sse", timeout=10.0), "sse", "mcp.client.session_group.sse_client", ), # url, headers, timeout, sse_read_timeout ( StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), "streamablehttp", - "mcp.client.session_group.streamablehttp_client", + "mcp.client.session_group.streamable_http_client", ), # url, headers, timeout, sse_read_timeout, terminate_on_close ], ) @@ -288,7 +296,7 @@ async def test_establish_session_parameterized( mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") - # streamablehttp_client's __aenter__ returns three values + # streamable_http_client's __aenter__ returns three values if client_type_name == "streamablehttp": mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra") mock_client_cm_instance.__aenter__.return_value = ( @@ -329,7 +337,7 @@ async def test_establish_session_parameterized( ( returned_server_info, returned_session, - ) = await group._establish_session(server_params_instance) + ) = await group._establish_session(server_params_instance, ClientSessionParameters()) # --- Assertions --- # 1. Assert the correct specific client function was called @@ -344,20 +352,31 @@ async def test_establish_session_parameterized( timeout=server_params_instance.timeout, sse_read_timeout=server_params_instance.sse_read_timeout, ) - elif client_type_name == "streamablehttp": + elif client_type_name == "streamablehttp": # pragma: no branch assert isinstance(server_params_instance, StreamableHttpParameters) - mock_specific_client_func.assert_called_once_with( - url=server_params_instance.url, - headers=server_params_instance.headers, - timeout=server_params_instance.timeout, - sse_read_timeout=server_params_instance.sse_read_timeout, - terminate_on_close=server_params_instance.terminate_on_close, - ) + # Verify streamable_http_client was called with url, httpx_client, and terminate_on_close + # The http_client is created by the real create_mcp_http_client + import httpx + + call_args = mock_specific_client_func.call_args + assert call_args.kwargs["url"] == server_params_instance.url + assert call_args.kwargs["terminate_on_close"] == server_params_instance.terminate_on_close + assert isinstance(call_args.kwargs["http_client"], httpx.AsyncClient) mock_client_cm_instance.__aenter__.assert_awaited_once() # 2. Assert ClientSession was called correctly - mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_ClientSession_class.assert_called_once_with( + mock_read_stream, + mock_write_stream, + read_timeout_seconds=None, + sampling_callback=None, + elicitation_callback=None, + list_roots_callback=None, + logging_callback=None, + message_handler=None, + client_info=None, + ) mock_raw_session_cm.__aenter__.assert_awaited_once() mock_entered_session.initialize.assert_awaited_once() diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 69dad4846a..ba58da7321 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,3 +1,4 @@ +import errno import os import shutil import sys @@ -53,7 +54,7 @@ async def test_stdio_client(): read_messages: list[JSONRPCMessage] = [] async with read_stream: async for message in read_stream: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message read_messages.append(message.message) @@ -68,7 +69,7 @@ async def test_stdio_client(): @pytest.mark.anyio async def test_stdio_client_bad_path(): """Check that the connection doesn't hang if process errors.""" - server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"]) + server_params = StdioServerParameters(command=sys.executable, args=["-c", "non-existent-file.py"]) async with stdio_client(server_params) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # The session should raise an error when the connection closes @@ -90,17 +91,12 @@ async def test_stdio_client_nonexistent_command(): ) # Should raise an error when trying to start the process - with pytest.raises(Exception) as exc_info: + with pytest.raises(OSError) as exc_info: async with stdio_client(server_params) as (_, _): - pass + pass # pragma: no cover - # The error should indicate the command was not found - error_message = str(exc_info.value) - assert ( - "nonexistent" in error_message - or "not found" in error_message.lower() - or "cannot find the file" in error_message.lower() # Windows error message - ) + # The error should indicate the command was not found (ENOENT: No such file or directory) + assert exc_info.value.errno == errno.ENOENT @pytest.mark.anyio @@ -148,7 +144,7 @@ async def test_stdio_client_universal_cleanup(): ) # Check if we timed out - if cancel_scope.cancelled_caught: + if cancel_scope.cancelled_caught: # pragma: no cover pytest.fail( "stdio_client cleanup timed out after 8.0 seconds. " "This indicates the cleanup mechanism is hanging and needs fixing." @@ -157,7 +153,7 @@ async def test_stdio_client_universal_cleanup(): @pytest.mark.anyio @pytest.mark.skipif(sys.platform == "win32", reason="Windows signal handling is different") -async def test_stdio_client_sigint_only_process(): +async def test_stdio_client_sigint_only_process(): # pragma: no cover """ Test cleanup with a process that ignores SIGTERM but responds to SIGINT. """ @@ -199,7 +195,7 @@ def sigint_handler(signum, frame): # Exit context triggers cleanup - this should not hang pass - if cancel_scope.cancelled_caught: + if cancel_scope.cancelled_caught: # pragma: no cover raise TimeoutError("Test timed out") end_time = time.time() @@ -212,7 +208,7 @@ def sigint_handler(signum, frame): f"Expected < {SIGTERM_IGNORING_PROCESS_TIMEOUT} seconds. " "This suggests the cleanup needs SIGINT/SIGKILL fallback." ) - except (TimeoutError, Exception) as e: + except (TimeoutError, Exception) as e: # pragma: no cover if isinstance(e, TimeoutError) or "timed out" in str(e): pytest.fail( f"stdio_client cleanup timed out after {SIGTERM_IGNORING_PROCESS_TIMEOUT} seconds " @@ -307,7 +303,7 @@ async def test_basic_child_process_cleanup(self): assert os.path.exists(parent_marker), "Parent process didn't start" # Verify child is writing - if os.path.exists(marker_file): + if os.path.exists(marker_file): # pragma: no branch initial_size = os.path.getsize(marker_file) await anyio.sleep(0.3) size_after_wait = os.path.getsize(marker_file) @@ -322,7 +318,7 @@ async def test_basic_child_process_cleanup(self): # Verify processes stopped await anyio.sleep(0.5) - if os.path.exists(marker_file): + if os.path.exists(marker_file): # pragma: no branch size_after_cleanup = os.path.getsize(marker_file) await anyio.sleep(0.5) final_size = os.path.getsize(marker_file) @@ -339,7 +335,7 @@ async def test_basic_child_process_cleanup(self): for f in [marker_file, parent_marker]: try: os.unlink(f) - except OSError: + except OSError: # pragma: no cover pass @pytest.mark.anyio @@ -410,7 +406,7 @@ async def test_nested_process_tree(self): # Verify all are writing for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): + if os.path.exists(file_path): # pragma: no branch initial_size = os.path.getsize(file_path) await anyio.sleep(0.3) new_size = os.path.getsize(file_path) @@ -424,7 +420,7 @@ async def test_nested_process_tree(self): # Verify all stopped await anyio.sleep(0.5) for file_path, name in [(parent_file, "parent"), (child_file, "child"), (grandchild_file, "grandchild")]: - if os.path.exists(file_path): + if os.path.exists(file_path): # pragma: no branch size1 = os.path.getsize(file_path) await anyio.sleep(0.3) size2 = os.path.getsize(file_path) @@ -437,7 +433,7 @@ async def test_nested_process_tree(self): for f in [parent_file, child_file, grandchild_file]: try: os.unlink(f) - except OSError: + except OSError: # pragma: no cover pass @pytest.mark.anyio @@ -491,7 +487,7 @@ def handle_term(sig, frame): await anyio.sleep(0.5) # Verify child is writing - if os.path.exists(marker_file): + if os.path.exists(marker_file): # pragma: no cover size1 = os.path.getsize(marker_file) await anyio.sleep(0.3) size2 = os.path.getsize(marker_file) @@ -504,7 +500,7 @@ def handle_term(sig, frame): # Verify child stopped await anyio.sleep(0.5) - if os.path.exists(marker_file): + if os.path.exists(marker_file): # pragma: no branch size3 = os.path.getsize(marker_file) await anyio.sleep(0.3) size4 = os.path.getsize(marker_file) @@ -516,7 +512,7 @@ def handle_term(sig, frame): # Clean up marker file try: os.unlink(marker_file) - except OSError: + except OSError: # pragma: no cover pass @@ -564,7 +560,7 @@ async def test_stdio_client_graceful_stdin_exit(): pytest.fail( "stdio_client cleanup timed out after 5.0 seconds. " "Process should have exited gracefully when stdin was closed." - ) + ) # pragma: no cover end_time = time.time() elapsed = end_time - start_time @@ -623,7 +619,7 @@ def sigterm_handler(signum, frame): pytest.fail( "stdio_client cleanup timed out after 7.0 seconds. " "Process should have been terminated via SIGTERM escalation." - ) + ) # pragma: no cover end_time = time.time() elapsed = end_time - start_time diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/tasks/__init__.py b/tests/experimental/tasks/__init__.py new file mode 100644 index 0000000000..6e8649d283 --- /dev/null +++ b/tests/experimental/tasks/__init__.py @@ -0,0 +1 @@ +"""Tests for MCP task support.""" diff --git a/tests/experimental/tasks/client/__init__.py b/tests/experimental/tasks/client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py new file mode 100644 index 0000000000..f2def4e3a6 --- /dev/null +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -0,0 +1,331 @@ +"""Tests for client task capabilities declaration during initialization.""" + +import anyio +import pytest + +import mcp.types as types +from mcp import ClientCapabilities +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientRequest, + Implementation, + InitializeRequest, + InitializeResult, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + ServerResult, +) + + +@pytest.mark.anyio +async def test_client_capabilities_without_tasks(): + """Test that tasks capability is None when not provided.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities = None + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability is None when not provided + assert received_capabilities is not None + assert received_capabilities.tasks is None + + +@pytest.mark.anyio +async def test_client_capabilities_with_tasks(): + """Test that tasks capability is properly set when handlers are provided.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities: ClientCapabilities | None = None + + # Define custom handlers to trigger capability building (never actually called) + async def my_list_tasks_handler( + context: RequestContext[ClientSession, None], + params: types.PaginatedRequestParams | None, + ) -> types.ListTasksResult | types.ErrorData: + raise NotImplementedError + + async def my_cancel_task_handler( + context: RequestContext[ClientSession, None], + params: types.CancelTaskRequestParams, + ) -> types.CancelTaskResult | types.ErrorData: + raise NotImplementedError + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + # Create handlers container + task_handlers = ExperimentalTaskHandlers( + list_tasks=my_list_tasks_handler, + cancel_task=my_cancel_task_handler, + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability is properly set from handlers + assert received_capabilities is not None + assert received_capabilities.tasks is not None + assert isinstance(received_capabilities.tasks, types.ClientTasksCapability) + assert received_capabilities.tasks.list is not None + assert received_capabilities.tasks.cancel is not None + + +@pytest.mark.anyio +async def test_client_capabilities_auto_built_from_handlers(): + """Test that tasks capability is automatically built from provided handlers.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities: ClientCapabilities | None = None + + # Define custom handlers (not defaults) + async def my_list_tasks_handler( + context: RequestContext[ClientSession, None], + params: types.PaginatedRequestParams | None, + ) -> types.ListTasksResult | types.ErrorData: + raise NotImplementedError + + async def my_cancel_task_handler( + context: RequestContext[ClientSession, None], + params: types.CancelTaskRequestParams, + ) -> types.CancelTaskResult | types.ErrorData: + raise NotImplementedError + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + # Provide handlers via ExperimentalTaskHandlers + task_handlers = ExperimentalTaskHandlers( + list_tasks=my_list_tasks_handler, + cancel_task=my_cancel_task_handler, + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability was auto-built from handlers + assert received_capabilities is not None + assert received_capabilities.tasks is not None + assert received_capabilities.tasks.list is not None + assert received_capabilities.tasks.cancel is not None + # requests should be None since we didn't provide task-augmented handlers + assert received_capabilities.tasks.requests is None + + +@pytest.mark.anyio +async def test_client_capabilities_with_task_augmented_handlers(): + """Test that requests capability is built when augmented handlers are provided.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + received_capabilities: ClientCapabilities | None = None + + # Define task-augmented handler + async def my_augmented_sampling_handler( + context: RequestContext[ClientSession, None], + params: types.CreateMessageRequestParams, + task_metadata: types.TaskMetadata, + ) -> types.CreateTaskResult | types.ErrorData: + raise NotImplementedError + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + # Provide task-augmented sampling handler + task_handlers = ExperimentalTaskHandlers( + augmented_sampling=my_augmented_sampling_handler, + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that tasks capability includes requests.sampling + assert received_capabilities is not None + assert received_capabilities.tasks is not None + assert received_capabilities.tasks.requests is not None + assert received_capabilities.tasks.requests.sampling is not None + assert received_capabilities.tasks.requests.elicitation is None # Not provided diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py new file mode 100644 index 0000000000..86cea42ae1 --- /dev/null +++ b/tests/experimental/tasks/client/test_handlers.py @@ -0,0 +1,878 @@ +"""Tests for client-side task management handlers (server -> client requests). + +These tests verify that clients can handle task-related requests from servers: +- GetTaskRequest - server polling client's task status +- GetTaskPayloadRequest - server getting result from client's task +- ListTasksRequest - server listing client's tasks +- CancelTaskRequest - server cancelling client's task + +This is the inverse of the existing tests in test_tasks.py, which test +client -> server task requests. +""" + +from collections.abc import AsyncIterator +from dataclasses import dataclass + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +import mcp.types as types +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + CancelTaskRequest, + CancelTaskRequestParams, + CancelTaskResult, + ClientResult, + CreateMessageRequest, + CreateMessageRequestParams, + CreateMessageResult, + CreateTaskResult, + ElicitRequest, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + ErrorData, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + SamplingMessage, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, +) + +# Buffer size for test streams +STREAM_BUFFER_SIZE = 10 + + +@dataclass +class ClientTestStreams: + """Bidirectional message streams for client/server communication in tests.""" + + server_send: MemoryObjectSendStream[SessionMessage] + server_receive: MemoryObjectReceiveStream[SessionMessage] + client_send: MemoryObjectSendStream[SessionMessage] + client_receive: MemoryObjectReceiveStream[SessionMessage] + + +@pytest.fixture +async def client_streams() -> AsyncIterator[ClientTestStreams]: + """Create bidirectional message streams for client tests. + + Automatically closes all streams after the test completes. + """ + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage]( + STREAM_BUFFER_SIZE + ) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage]( + STREAM_BUFFER_SIZE + ) + + streams = ClientTestStreams( + server_send=server_to_client_send, + server_receive=client_to_server_receive, + client_send=client_to_server_send, + client_receive=server_to_client_receive, + ) + + yield streams + + # Cleanup + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +async def _default_message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, +) -> None: + """Default message handler that ignores messages (tests handle them explicitly).""" + ... + + +@pytest.mark.anyio +async def test_client_handles_get_task_request(client_streams: ClientTestStreams) -> None: + """Test that client can respond to GetTaskRequest from server.""" + with anyio.fail_after(10): + store = InMemoryTaskStore() + received_task_id: str | None = None + + async def get_task_handler( + context: RequestContext[ClientSession, None], + params: GetTaskRequestParams, + ) -> GetTaskResult | ErrorData: + nonlocal received_task_id + received_task_id = params.taskId + task = await store.get_task(params.taskId) + assert task is not None, f"Test setup error: task {params.taskId} should exist" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-123") + + task_handlers = ExperimentalTaskHandlers(get_task=get_task_handler) + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = GetTaskRequest(params=GetTaskRequestParams(taskId="test-task-123")) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-1", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + assert response.id == "req-1" + + result = GetTaskResult.model_validate(response.result) + assert result.taskId == "test-task-123" + assert result.status == "working" + assert received_task_id == "test-task-123" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_client_handles_get_task_result_request(client_streams: ClientTestStreams) -> None: + """Test that client can respond to GetTaskPayloadRequest from server.""" + with anyio.fail_after(10): + store = InMemoryTaskStore() + + async def get_task_result_handler( + context: RequestContext[ClientSession, None], + params: GetTaskPayloadRequestParams, + ) -> GetTaskPayloadResult | ErrorData: + result = await store.get_result(params.taskId) + assert result is not None, f"Test setup error: result for {params.taskId} should exist" + assert isinstance(result, types.CallToolResult) + return GetTaskPayloadResult(**result.model_dump()) + + await store.create_task(TaskMetadata(ttl=60000), task_id="test-task-456") + await store.store_result( + "test-task-456", + types.CallToolResult(content=[TextContent(type="text", text="Task completed successfully!")]), + ) + await store.update_task("test-task-456", status="completed") + + task_handlers = ExperimentalTaskHandlers(get_task_result=get_task_result_handler) + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId="test-task-456")) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-2", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + assert isinstance(response.result, dict) + result_dict = response.result + assert "content" in result_dict + assert len(result_dict["content"]) == 1 + assert result_dict["content"][0]["text"] == "Task completed successfully!" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_client_handles_list_tasks_request(client_streams: ClientTestStreams) -> None: + """Test that client can respond to ListTasksRequest from server.""" + with anyio.fail_after(10): + store = InMemoryTaskStore() + + async def list_tasks_handler( + context: RequestContext[ClientSession, None], + params: types.PaginatedRequestParams | None, + ) -> ListTasksResult | ErrorData: + cursor = params.cursor if params else None + tasks_list, next_cursor = await store.list_tasks(cursor=cursor) + return ListTasksResult(tasks=tasks_list, nextCursor=next_cursor) + + await store.create_task(TaskMetadata(ttl=60000), task_id="task-1") + await store.create_task(TaskMetadata(ttl=60000), task_id="task-2") + + task_handlers = ExperimentalTaskHandlers(list_tasks=list_tasks_handler) + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = ListTasksRequest() + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-3", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + result = ListTasksResult.model_validate(response.result) + assert len(result.tasks) == 2 + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_client_handles_cancel_task_request(client_streams: ClientTestStreams) -> None: + """Test that client can respond to CancelTaskRequest from server.""" + with anyio.fail_after(10): + store = InMemoryTaskStore() + + async def cancel_task_handler( + context: RequestContext[ClientSession, None], + params: CancelTaskRequestParams, + ) -> CancelTaskResult | ErrorData: + task = await store.get_task(params.taskId) + assert task is not None, f"Test setup error: task {params.taskId} should exist" + await store.update_task(params.taskId, status="cancelled") + updated = await store.get_task(params.taskId) + assert updated is not None + return CancelTaskResult( + taskId=updated.taskId, + status=updated.status, + createdAt=updated.createdAt, + lastUpdatedAt=updated.lastUpdatedAt, + ttl=updated.ttl, + ) + + await store.create_task(TaskMetadata(ttl=60000), task_id="task-to-cancel") + + task_handlers = ExperimentalTaskHandlers(cancel_task=cancel_task_handler) + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = CancelTaskRequest(params=CancelTaskRequestParams(taskId="task-to-cancel")) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-4", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + result = CancelTaskResult.model_validate(response.result) + assert result.taskId == "task-to-cancel" + assert result.status == "cancelled" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_client_task_augmented_sampling(client_streams: ClientTestStreams) -> None: + """Test that client can handle task-augmented sampling request from server.""" + with anyio.fail_after(10): + store = InMemoryTaskStore() + sampling_completed = Event() + created_task_id: list[str | None] = [None] + background_tg: list[TaskGroup | None] = [None] + + async def task_augmented_sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, + task_metadata: TaskMetadata, + ) -> CreateTaskResult: + task = await store.create_task(task_metadata) + created_task_id[0] = task.taskId + + async def do_sampling() -> None: + result = CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="Sampled response"), + model="test-model", + stopReason="endTurn", + ) + await store.store_result(task.taskId, result) + await store.update_task(task.taskId, status="completed") + sampling_completed.set() + + assert background_tg[0] is not None + background_tg[0].start_soon(do_sampling) + return CreateTaskResult(task=task) + + async def get_task_handler( + context: RequestContext[ClientSession, None], + params: GetTaskRequestParams, + ) -> GetTaskResult | ErrorData: + task = await store.get_task(params.taskId) + assert task is not None, f"Test setup error: task {params.taskId} should exist" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + async def get_task_result_handler( + context: RequestContext[ClientSession, None], + params: GetTaskPayloadRequestParams, + ) -> GetTaskPayloadResult | ErrorData: + result = await store.get_result(params.taskId) + assert result is not None, f"Test setup error: result for {params.taskId} should exist" + assert isinstance(result, CreateMessageResult) + return GetTaskPayloadResult(**result.model_dump()) + + task_handlers = ExperimentalTaskHandlers( + augmented_sampling=task_augmented_sampling_callback, + get_task=get_task_handler, + get_task_result=get_task_result_handler, + ) + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + background_tg[0] = tg + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Step 1: Server sends task-augmented CreateMessageRequest + typed_request = CreateMessageRequest( + params=CreateMessageRequestParams( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + maxTokens=100, + task=TaskMetadata(ttl=60000), + ) + ) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-sampling", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Step 2: Client responds with CreateTaskResult + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + task_result = CreateTaskResult.model_validate(response.result) + task_id = task_result.task.taskId + assert task_id == created_task_id[0] + + # Step 3: Wait for background sampling + await sampling_completed.wait() + + # Step 4: Server polls task status + typed_poll = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id)) + poll_request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-poll", + **typed_poll.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) + + poll_response_msg = await client_streams.server_receive.receive() + poll_response = poll_response_msg.message.root + assert isinstance(poll_response, types.JSONRPCResponse) + + status = GetTaskResult.model_validate(poll_response.result) + assert status.status == "completed" + + # Step 5: Server gets result + typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id)) + result_request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-result", + **typed_result_req.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request))) + + result_response_msg = await client_streams.server_receive.receive() + result_response = result_response_msg.message.root + assert isinstance(result_response, types.JSONRPCResponse) + + assert isinstance(result_response.result, dict) + assert result_response.result["role"] == "assistant" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_client_task_augmented_elicitation(client_streams: ClientTestStreams) -> None: + """Test that client can handle task-augmented elicitation request from server.""" + with anyio.fail_after(10): + store = InMemoryTaskStore() + elicitation_completed = Event() + created_task_id: list[str | None] = [None] + background_tg: list[TaskGroup | None] = [None] + + async def task_augmented_elicitation_callback( + context: RequestContext[ClientSession, None], + params: ElicitRequestParams, + task_metadata: TaskMetadata, + ) -> CreateTaskResult | ErrorData: + task = await store.create_task(task_metadata) + created_task_id[0] = task.taskId + + async def do_elicitation() -> None: + # Simulate user providing elicitation response + result = ElicitResult(action="http://github.com/accept", content={"name": "Test User"}) + await store.store_result(task.taskId, result) + await store.update_task(task.taskId, status="completed") + elicitation_completed.set() + + assert background_tg[0] is not None + background_tg[0].start_soon(do_elicitation) + return CreateTaskResult(task=task) + + async def get_task_handler( + context: RequestContext[ClientSession, None], + params: GetTaskRequestParams, + ) -> GetTaskResult | ErrorData: + task = await store.get_task(params.taskId) + assert task is not None, f"Test setup error: task {params.taskId} should exist" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + async def get_task_result_handler( + context: RequestContext[ClientSession, None], + params: GetTaskPayloadRequestParams, + ) -> GetTaskPayloadResult | ErrorData: + result = await store.get_result(params.taskId) + assert result is not None, f"Test setup error: result for {params.taskId} should exist" + assert isinstance(result, ElicitResult) + return GetTaskPayloadResult(**result.model_dump()) + + task_handlers = ExperimentalTaskHandlers( + augmented_elicitation=task_augmented_elicitation_callback, + get_task=get_task_handler, + get_task_result=get_task_result_handler, + ) + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + background_tg[0] = tg + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + experimental_task_handlers=task_handlers, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Step 1: Server sends task-augmented ElicitRequest + typed_request = ElicitRequest( + params=ElicitRequestFormParams( + message="What is your name?", + requestedSchema={"type": "object", "properties": {"name": {"type": "string"}}}, + task=TaskMetadata(ttl=60000), + ) + ) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-elicit", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + # Step 2: Client responds with CreateTaskResult + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCResponse) + + task_result = CreateTaskResult.model_validate(response.result) + task_id = task_result.task.taskId + assert task_id == created_task_id[0] + + # Step 3: Wait for background elicitation + await elicitation_completed.wait() + + # Step 4: Server polls task status + typed_poll = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id)) + poll_request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-poll", + **typed_poll.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) + + poll_response_msg = await client_streams.server_receive.receive() + poll_response = poll_response_msg.message.root + assert isinstance(poll_response, types.JSONRPCResponse) + + status = GetTaskResult.model_validate(poll_response.result) + assert status.status == "completed" + + # Step 5: Server gets result + typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id)) + result_request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-result", + **typed_result_req.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request))) + + result_response_msg = await client_streams.server_receive.receive() + result_response = result_response_msg.message.root + assert isinstance(result_response, types.JSONRPCResponse) + + # Verify the elicitation result + assert isinstance(result_response.result, dict) + assert result_response.result["action"] == "accept" + assert result_response.result["content"] == {"name": "Test User"} + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_client_returns_error_for_unhandled_task_request(client_streams: ClientTestStreams) -> None: + """Test that client returns error when no handler is registered for task request.""" + with anyio.fail_after(10): + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = GetTaskRequest(params=GetTaskRequestParams(taskId="nonexistent")) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-unhandled", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCError) + assert ( + "not supported" in response.error.message.lower() + or "method not found" in response.error.message.lower() + ) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_client_returns_error_for_unhandled_task_result_request(client_streams: ClientTestStreams) -> None: + """Test that client returns error for unhandled tasks/result request.""" + with anyio.fail_after(10): + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId="nonexistent")) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-result", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCError) + assert "not supported" in response.error.message.lower() + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_client_returns_error_for_unhandled_list_tasks_request(client_streams: ClientTestStreams) -> None: + """Test that client returns error for unhandled tasks/list request.""" + with anyio.fail_after(10): + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = ListTasksRequest() + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-list", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCError) + assert "not supported" in response.error.message.lower() + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_client_returns_error_for_unhandled_cancel_task_request(client_streams: ClientTestStreams) -> None: + """Test that client returns error for unhandled tasks/cancel request.""" + with anyio.fail_after(10): + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + typed_request = CancelTaskRequest(params=CancelTaskRequestParams(taskId="nonexistent")) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-cancel", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCError) + assert "not supported" in response.error.message.lower() + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_client_returns_error_for_unhandled_task_augmented_sampling(client_streams: ClientTestStreams) -> None: + """Test that client returns error for task-augmented sampling without handler.""" + with anyio.fail_after(10): + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + # No task handlers provided - uses defaults + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Send task-augmented sampling request + typed_request = CreateMessageRequest( + params=CreateMessageRequestParams( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + maxTokens=100, + task=TaskMetadata(ttl=60000), + ) + ) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-sampling", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCError) + assert "not supported" in response.error.message.lower() + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_client_returns_error_for_unhandled_task_augmented_elicitation( + client_streams: ClientTestStreams, +) -> None: + """Test that client returns error for task-augmented elicitation without handler.""" + with anyio.fail_after(10): + client_ready = anyio.Event() + + async with anyio.create_task_group() as tg: + + async def run_client() -> None: + # No task handlers provided - uses defaults + async with ClientSession( + client_streams.client_receive, + client_streams.client_send, + message_handler=_default_message_handler, + ): + client_ready.set() + await anyio.sleep_forever() + + tg.start_soon(run_client) + await client_ready.wait() + + # Send task-augmented elicitation request + typed_request = ElicitRequest( + params=ElicitRequestFormParams( + message="What is your name?", + requestedSchema={"type": "object", "properties": {"name": {"type": "string"}}}, + task=TaskMetadata(ttl=60000), + ) + ) + request = types.JSONRPCRequest( + jsonrpc="2.0", + id="req-elicit", + **typed_request.model_dump(by_alias=True), + ) + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + + response_msg = await client_streams.server_receive.receive() + response = response_msg.message.root + assert isinstance(response, types.JSONRPCError) + assert "not supported" in response.error.message.lower() + + tg.cancel_scope.cancel() diff --git a/tests/experimental/tasks/client/test_poll_task.py b/tests/experimental/tasks/client/test_poll_task.py new file mode 100644 index 0000000000..8275dc668e --- /dev/null +++ b/tests/experimental/tasks/client/test_poll_task.py @@ -0,0 +1,121 @@ +"""Tests for poll_task async iterator.""" + +from collections.abc import Callable, Coroutine +from datetime import datetime, timezone +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from mcp.client.experimental.tasks import ExperimentalClientFeatures +from mcp.types import GetTaskResult, TaskStatus + + +def make_task_result( + status: TaskStatus = "working", + poll_interval: int = 0, + task_id: str = "test-task", + status_message: str | None = None, +) -> GetTaskResult: + """Create GetTaskResult with sensible defaults.""" + now = datetime.now(timezone.utc) + return GetTaskResult( + taskId=task_id, + status=status, + statusMessage=status_message, + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + pollInterval=poll_interval, + ) + + +def make_status_sequence( + *statuses: TaskStatus, + task_id: str = "test-task", +) -> Callable[[str], Coroutine[Any, Any, GetTaskResult]]: + """Create mock get_task that returns statuses in sequence.""" + status_iter = iter(statuses) + + async def mock_get_task(tid: str) -> GetTaskResult: + return make_task_result(status=next(status_iter), task_id=tid) + + return mock_get_task + + +@pytest.fixture +def mock_session() -> AsyncMock: + return AsyncMock() + + +@pytest.fixture +def features(mock_session: AsyncMock) -> ExperimentalClientFeatures: + return ExperimentalClientFeatures(mock_session) + + +@pytest.mark.anyio +async def test_poll_task_yields_until_completed(features: ExperimentalClientFeatures) -> None: + """poll_task yields each status until terminal.""" + features.get_task = make_status_sequence("working", "working", "completed") # type: ignore[method-assign] + + statuses = [s.status async for s in features.poll_task("test-task")] + + assert statuses == ["working", "working", "completed"] + + +@pytest.mark.anyio +@pytest.mark.parametrize("terminal_status", ["completed", "failed", "cancelled"]) +async def test_poll_task_exits_on_terminal(features: ExperimentalClientFeatures, terminal_status: TaskStatus) -> None: + """poll_task exits immediately when task is already terminal.""" + features.get_task = make_status_sequence(terminal_status) # type: ignore[method-assign] + + statuses = [s.status async for s in features.poll_task("test-task")] + + assert statuses == [terminal_status] + + +@pytest.mark.anyio +async def test_poll_task_continues_through_input_required(features: ExperimentalClientFeatures) -> None: + """poll_task yields input_required and continues (non-terminal).""" + features.get_task = make_status_sequence("working", "input_required", "working", "completed") # type: ignore[method-assign] + + statuses = [s.status async for s in features.poll_task("test-task")] + + assert statuses == ["working", "input_required", "working", "completed"] + + +@pytest.mark.anyio +async def test_poll_task_passes_task_id(features: ExperimentalClientFeatures) -> None: + """poll_task passes correct task_id to get_task.""" + received_ids: list[str] = [] + + async def mock_get_task(task_id: str) -> GetTaskResult: + received_ids.append(task_id) + return make_task_result(status="completed", task_id=task_id) + + features.get_task = mock_get_task # type: ignore[method-assign] + + _ = [s async for s in features.poll_task("my-task-123")] + + assert received_ids == ["my-task-123"] + + +@pytest.mark.anyio +async def test_poll_task_yields_full_result(features: ExperimentalClientFeatures) -> None: + """poll_task yields complete GetTaskResult objects.""" + + async def mock_get_task(task_id: str) -> GetTaskResult: + return make_task_result( + status="completed", + task_id=task_id, + status_message="All done!", + ) + + features.get_task = mock_get_task # type: ignore[method-assign] + + results = [r async for r in features.poll_task("test-task")] + + assert len(results) == 1 + assert results[0].status == "completed" + assert results[0].statusMessage == "All done!" + assert results[0].taskId == "test-task" diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py new file mode 100644 index 0000000000..24c8891def --- /dev/null +++ b/tests/experimental/tasks/client/test_tasks.py @@ -0,0 +1,483 @@ +"""Tests for the experimental client task methods (session.experimental).""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks.helpers import task_execution +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + CancelTaskRequest, + CancelTaskResult, + ClientRequest, + ClientResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, + Tool, +) + + +@dataclass +class AppContext: + """Application context passed via lifespan_context.""" + + task_group: TaskGroup + store: InMemoryTaskStore + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_session_experimental_get_task() -> None: + """Test session.experimental.get_task() method.""" + # Note: We bypass the normal lifespan mechanism + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + raise NotImplementedError + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + assert task is not None, f"Test setup error: task {request.params.taskId} should exist" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Wait for task to complete + await app_context.task_done_events[task_id].wait() + + # Use session.experimental to get task status + task_status = await client_session.experimental.get_task(task_id) + + assert task_status.taskId == task_id + assert task_status.status == "completed" + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_session_experimental_get_task_result() -> None: + """Test session.experimental.get_task_result() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text="Task result content")]) + ) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + raise NotImplementedError + + @server.experimental.get_task_result() + async def handle_get_task_result( + request: GetTaskPayloadRequest, + ) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + assert result is not None, f"Test setup error: result for {request.params.taskId} should exist" + assert isinstance(result, CallToolResult) + return GetTaskPayloadResult(**result.model_dump()) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Wait for task to complete + await app_context.task_done_events[task_id].wait() + + # Use TaskClient to get task result + task_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Task result content" + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_session_experimental_list_tasks() -> None: + """Test TaskClient.list_tasks() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + raise NotImplementedError + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + app = server.request_context.lifespan_context + tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + return ListTasksResult(tasks=tasks_list, nextCursor=next_cursor) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create two tasks + for _ in range(2): + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + await app_context.task_done_events[create_result.task.taskId].wait() + + # Use TaskClient to list tasks + list_result = await client_session.experimental.list_tasks() + + assert len(list_result.tasks) == 2 + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_session_experimental_cancel_task() -> None: + """Test TaskClient.cancel_task() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + # Don't start any work - task stays in "working" status + return CreateTaskResult(task=task) + + raise NotImplementedError + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + assert task is not None, f"Test setup error: task {request.params.taskId} should exist" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + assert task is not None, f"Test setup error: task {request.params.taskId} should exist" + await app.store.update_task(request.params.taskId, status="cancelled") + # CancelTaskResult extends Task, so we need to return the updated task info + updated_task = await app.store.get_task(request.params.taskId) + assert updated_task is not None + return CancelTaskResult( + taskId=updated_task.taskId, + status=updated_task.status, + createdAt=updated_task.createdAt, + lastUpdatedAt=updated_task.lastUpdatedAt, + ttl=updated_task.ttl, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task (but don't complete it) + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Verify task is working + status_before = await client_session.experimental.get_task(task_id) + assert status_before.status == "working" + + # Cancel the task + await client_session.experimental.cancel_task(task_id) + + # Verify task is cancelled + status_after = await client_session.experimental.get_task(task_id) + assert status_after.status == "cancelled" + + tg.cancel_scope.cancel() diff --git a/tests/experimental/tasks/server/__init__.py b/tests/experimental/tasks/server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/experimental/tasks/server/test_context.py b/tests/experimental/tasks/server/test_context.py new file mode 100644 index 0000000000..2f09ff1540 --- /dev/null +++ b/tests/experimental/tasks/server/test_context.py @@ -0,0 +1,183 @@ +"""Tests for TaskContext and helper functions.""" + +import pytest + +from mcp.shared.experimental.tasks.context import TaskContext +from mcp.shared.experimental.tasks.helpers import create_task_state, task_execution +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.types import CallToolResult, TaskMetadata, TextContent + + +@pytest.mark.anyio +async def test_task_context_properties() -> None: + """Test TaskContext basic properties.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store) + + assert ctx.task_id == task.taskId + assert ctx.task.taskId == task.taskId + assert ctx.task.status == "working" + assert ctx.is_cancelled is False + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_update_status() -> None: + """Test TaskContext.update_status.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store) + + await ctx.update_status("Processing step 1...") + + # Check status message was updated + updated = await store.get_task(task.taskId) + assert updated is not None + assert updated.statusMessage == "Processing step 1..." + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_complete() -> None: + """Test TaskContext.complete.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store) + + result = CallToolResult(content=[TextContent(type="text", text="Done!")]) + await ctx.complete(result) + + # Check task status + updated = await store.get_task(task.taskId) + assert updated is not None + assert updated.status == "completed" + + # Check result is stored + stored_result = await store.get_result(task.taskId) + assert stored_result is not None + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_fail() -> None: + """Test TaskContext.fail.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store) + + await ctx.fail("Something went wrong!") + + # Check task status + updated = await store.get_task(task.taskId) + assert updated is not None + assert updated.status == "failed" + assert updated.statusMessage == "Something went wrong!" + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_cancellation() -> None: + """Test TaskContext cancellation request.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store) + + assert ctx.is_cancelled is False + + ctx.request_cancellation() + + assert ctx.is_cancelled is True + + store.cleanup() + + +def test_create_task_state_generates_id() -> None: + """create_task_state generates a unique task ID when none provided.""" + task1 = create_task_state(TaskMetadata(ttl=60000)) + task2 = create_task_state(TaskMetadata(ttl=60000)) + + assert task1.taskId != task2.taskId + + +def test_create_task_state_uses_provided_id() -> None: + """create_task_state uses the provided task ID.""" + task = create_task_state(TaskMetadata(ttl=60000), task_id="my-task-123") + assert task.taskId == "my-task-123" + + +def test_create_task_state_null_ttl() -> None: + """create_task_state handles null TTL.""" + task = create_task_state(TaskMetadata(ttl=None)) + assert task.ttl is None + + +def test_create_task_state_has_created_at() -> None: + """create_task_state sets createdAt timestamp.""" + task = create_task_state(TaskMetadata(ttl=60000)) + assert task.createdAt is not None + + +@pytest.mark.anyio +async def test_task_execution_provides_context() -> None: + """task_execution provides a TaskContext for the task.""" + store = InMemoryTaskStore() + await store.create_task(TaskMetadata(ttl=60000), task_id="exec-test-1") + + async with task_execution("exec-test-1", store) as ctx: + assert ctx.task_id == "exec-test-1" + assert ctx.task.status == "working" + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_execution_auto_fails_on_exception() -> None: + """task_execution automatically fails task on unhandled exception.""" + store = InMemoryTaskStore() + await store.create_task(TaskMetadata(ttl=60000), task_id="exec-fail-1") + + async with task_execution("exec-fail-1", store): + raise RuntimeError("Oops!") + + # Task should be failed + failed_task = await store.get_task("exec-fail-1") + assert failed_task is not None + assert failed_task.status == "failed" + assert "Oops!" in (failed_task.statusMessage or "") + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_execution_doesnt_fail_if_already_terminal() -> None: + """task_execution doesn't re-fail if task already terminal.""" + store = InMemoryTaskStore() + await store.create_task(TaskMetadata(ttl=60000), task_id="exec-term-1") + + async with task_execution("exec-term-1", store) as ctx: + # Complete the task first + await ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) + # Then raise - shouldn't change status + raise RuntimeError("This shouldn't matter") + + # Task should remain completed + final_task = await store.get_task("exec-term-1") + assert final_task is not None + assert final_task.status == "completed" + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_execution_not_found() -> None: + """task_execution raises ValueError for non-existent task.""" + store = InMemoryTaskStore() + + with pytest.raises(ValueError, match="not found"): + async with task_execution("nonexistent", store): + ... diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py new file mode 100644 index 0000000000..ba61dfcead --- /dev/null +++ b/tests/experimental/tasks/server/test_integration.py @@ -0,0 +1,357 @@ +"""End-to-end integration tests for tasks functionality. + +These tests demonstrate the full task lifecycle: +1. Client sends task-augmented request (tools/call with task metadata) +2. Server creates task and returns CreateTaskResult immediately +3. Background work executes (using task_execution context manager) +4. Client polls with tasks/get +5. Client retrieves result with tasks/result +""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks.helpers import task_execution +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + TASK_REQUIRED, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientRequest, + ClientResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, + Tool, + ToolExecution, +) + + +@dataclass +class AppContext: + """Application context passed via lifespan_context.""" + + task_group: TaskGroup + store: InMemoryTaskStore + # Events to signal when tasks complete (for testing without sleeps) + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_task_lifecycle_with_task_execution() -> None: + """ + Test the complete task lifecycle using the task_execution pattern. + + This demonstrates the recommended way to implement task-augmented tools: + 1. Create task in store + 2. Spawn work using task_execution() context manager + 3. Return CreateTaskResult immediately + 4. Work executes in background, auto-fails on exception + """ + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext, Any] = Server("test-tasks") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="process_data", + description="Process data asynchronously", + inputSchema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + if name == "process_data" and ctx.experimental.is_task: + # 1. Create task in store + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + # 2. Create event to signal completion (for testing) + done_event = Event() + app.task_done_events[task.taskId] = done_event + + # 3. Define work function using task_execution for safety + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("Processing input...") + # Simulate work + input_value = arguments.get("input", "") + result_text = f"Processed: {input_value.upper()}" + await task_ctx.complete(CallToolResult(content=[TextContent(type="text", text=result_text)])) + # Signal completion + done_event.set() + + # 4. Spawn work in task group (from lifespan_context) + app.task_group.start_soon(do_work) + + # 5. Return CreateTaskResult immediately + return CreateTaskResult(task=task) + + raise NotImplementedError + + # Register task query handlers (delegate to store) + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + assert task is not None, f"Test setup error: task {request.params.taskId} should exist" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.get_task_result() + async def handle_get_task_result( + request: GetTaskPayloadRequest, + ) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + assert result is not None, f"Test setup error: result for {request.params.taskId} should exist" + assert isinstance(result, CallToolResult) + # Return as GetTaskPayloadResult (which accepts extra fields) + return GetTaskPayloadResult(**result.model_dump()) + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + raise NotImplementedError + + # Set up client-server communication + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no cover + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + # Create app context with task group and store + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # === Step 1: Send task-augmented tool call === + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + assert isinstance(create_result, CreateTaskResult) + assert create_result.task.status == "working" + task_id = create_result.task.taskId + + # === Step 2: Wait for task to complete === + await app_context.task_done_events[task_id].wait() + + task_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + assert task_status.taskId == task_id + assert task_status.status == "completed" + + # === Step 3: Retrieve the actual result === + task_result = await client_session.send_request( + ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id))), + CallToolResult, + ) + + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Processed: HELLO WORLD" + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_task_auto_fails_on_exception() -> None: + """Test that task_execution automatically fails the task on unhandled exception.""" + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext, Any] = Server("test-tasks-failure") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="failing_task", + description="A task that fails", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + if name == "failing_task" and ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + # Create event to signal completion (for testing) + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_failing_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("About to fail...") + raise RuntimeError("Something went wrong!") + # Note: complete() is never called, but task_execution + # will automatically call fail() due to the exception + # This line is reached because task_execution suppresses the exception + done_event.set() + + app.task_group.start_soon(do_failing_work) + return CreateTaskResult(task=task) + + raise NotImplementedError + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + assert task is not None, f"Test setup error: task {request.params.taskId} should exist" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no cover + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Send task request + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + task_id = create_result.task.taskId + + # Wait for task to complete (even though it fails) + await app_context.task_done_events[task_id].wait() + + # Check that task was auto-failed + task_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + assert task_status.status == "failed" + assert task_status.statusMessage == "Something went wrong!" + + tg.cancel_scope.cancel() diff --git a/tests/experimental/tasks/server/test_run_task_flow.py b/tests/experimental/tasks/server/test_run_task_flow.py new file mode 100644 index 0000000000..7f680beb66 --- /dev/null +++ b/tests/experimental/tasks/server/test_run_task_flow.py @@ -0,0 +1,538 @@ +""" +Tests for the simplified task API: enable_tasks() + run_task() + +This tests the recommended user flow: +1. server.experimental.enable_tasks() - one-line setup +2. ctx.experimental.run_task(work) - spawns work, returns CreateTaskResult +3. work function uses ServerTaskContext for elicit/create_message + +These are integration tests that verify the complete flow works end-to-end. +""" + +from typing import Any +from unittest.mock import Mock + +import anyio +import pytest +from anyio import Event + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.experimental.request_context import Experimental +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.experimental.task_support import TaskSupport +from mcp.server.lowlevel import NotificationOptions +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolResult, + CancelTaskRequest, + CancelTaskResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + TextContent, + Tool, + ToolExecution, +) + + +@pytest.mark.anyio +async def test_run_task_basic_flow() -> None: + """ + Test the basic run_task flow without elicitation. + + 1. enable_tasks() sets up handlers + 2. Client calls tool with task field + 3. run_task() spawns work, returns CreateTaskResult + 4. Work completes in background + 5. Client polls and sees completed status + """ + server = Server("test-run-task") + + # One-line setup + server.experimental.enable_tasks() + + # Track when work completes and capture received meta + work_completed = Event() + received_meta: list[str | None] = [None] + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="simple_task", + description="A simple task", + inputSchema={"type": "object", "properties": {"input": {"type": "string"}}}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + # Capture the meta from the request (if present) + if ctx.meta is not None and ctx.meta.model_extra: # pragma: no branch + received_meta[0] = ctx.meta.model_extra.get("custom_field") + + async def work(task: ServerTaskContext) -> CallToolResult: + await task.update_status("Working...") + input_val = arguments.get("input", "default") + result = CallToolResult(content=[TextContent(type="text", text=f"Processed: {input_val}")]) + work_completed.set() + return result + + return await ctx.experimental.run_task(work) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ) + + async def run_client() -> None: + async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: + # Initialize + await client_session.initialize() + + # Call tool as task (with meta to test that code path) + result = await client_session.experimental.call_tool_as_task( + "simple_task", + {"input": "hello"}, + meta={"custom_field": "test_value"}, + ) + + # Should get CreateTaskResult + task_id = result.task.taskId + assert result.task.status == "working" + + # Wait for work to complete + with anyio.fail_after(5): + await work_completed.wait() + + # Poll until task status is completed + with anyio.fail_after(5): + while True: + task_status = await client_session.experimental.get_task(task_id) + if task_status.status == "completed": # pragma: no branch + break + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + # Verify the meta was passed through correctly + assert received_meta[0] == "test_value" + + +@pytest.mark.anyio +async def test_run_task_auto_fails_on_exception() -> None: + """ + Test that run_task automatically fails the task when work raises. + """ + server = Server("test-run-task-fail") + server.experimental.enable_tasks() + + work_failed = Event() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="failing_task", + description="A task that fails", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + work_failed.set() + raise RuntimeError("Something went wrong!") + + return await ctx.experimental.run_task(work) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def run_client() -> None: + async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: + await client_session.initialize() + + result = await client_session.experimental.call_tool_as_task("failing_task", {}) + task_id = result.task.taskId + + # Wait for work to fail + with anyio.fail_after(5): + await work_failed.wait() + + # Poll until task status is failed + with anyio.fail_after(5): + while True: + task_status = await client_session.experimental.get_task(task_id) + if task_status.status == "failed": # pragma: no branch + break + + assert "Something went wrong" in (task_status.statusMessage or "") + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + +@pytest.mark.anyio +async def test_enable_tasks_auto_registers_handlers() -> None: + """ + Test that enable_tasks() auto-registers get_task, list_tasks, cancel_task handlers. + """ + server = Server("test-enable-tasks") + + # Before enable_tasks, no task capabilities + caps_before = server.get_capabilities(NotificationOptions(), {}) + assert caps_before.tasks is None + + # Enable tasks + server.experimental.enable_tasks() + + # After enable_tasks, should have task capabilities + caps_after = server.get_capabilities(NotificationOptions(), {}) + assert caps_after.tasks is not None + assert caps_after.tasks.list is not None + assert caps_after.tasks.cancel is not None + + +@pytest.mark.anyio +async def test_enable_tasks_with_custom_store_and_queue() -> None: + """Test that enable_tasks() uses provided store and queue instead of defaults.""" + server = Server("test-custom-store-queue") + + # Create custom store and queue + custom_store = InMemoryTaskStore() + custom_queue = InMemoryTaskMessageQueue() + + # Enable tasks with custom implementations + task_support = server.experimental.enable_tasks(store=custom_store, queue=custom_queue) + + # Verify our custom implementations are used + assert task_support.store is custom_store + assert task_support.queue is custom_queue + + +@pytest.mark.anyio +async def test_enable_tasks_skips_default_handlers_when_custom_registered() -> None: + """Test that enable_tasks() doesn't override already-registered handlers.""" + server = Server("test-custom-handlers") + + # Register custom handlers BEFORE enable_tasks (never called, just for registration) + @server.experimental.get_task() + async def custom_get_task(req: GetTaskRequest) -> GetTaskResult: + raise NotImplementedError + + @server.experimental.get_task_result() + async def custom_get_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: + raise NotImplementedError + + @server.experimental.list_tasks() + async def custom_list_tasks(req: ListTasksRequest) -> ListTasksResult: + raise NotImplementedError + + @server.experimental.cancel_task() + async def custom_cancel_task(req: CancelTaskRequest) -> CancelTaskResult: + raise NotImplementedError + + # Now enable tasks - should NOT override our custom handlers + server.experimental.enable_tasks() + + # Verify our custom handlers are still registered (not replaced by defaults) + # The handlers dict should contain our custom handlers + assert GetTaskRequest in server.request_handlers + assert GetTaskPayloadRequest in server.request_handlers + assert ListTasksRequest in server.request_handlers + assert CancelTaskRequest in server.request_handlers + + +@pytest.mark.anyio +async def test_run_task_without_enable_tasks_raises() -> None: + """Test that run_task raises when enable_tasks() wasn't called.""" + experimental = Experimental( + task_metadata=None, + _client_capabilities=None, + _session=None, + _task_support=None, # Not enabled + ) + + async def work(task: ServerTaskContext) -> CallToolResult: + raise NotImplementedError + + with pytest.raises(RuntimeError, match="Task support not enabled"): + await experimental.run_task(work) + + +@pytest.mark.anyio +async def test_task_support_task_group_before_run_raises() -> None: + """Test that accessing task_group before run() raises RuntimeError.""" + task_support = TaskSupport.in_memory() + + with pytest.raises(RuntimeError, match="TaskSupport not running"): + _ = task_support.task_group + + +@pytest.mark.anyio +async def test_run_task_without_session_raises() -> None: + """Test that run_task raises when session is not available.""" + task_support = TaskSupport.in_memory() + + experimental = Experimental( + task_metadata=None, + _client_capabilities=None, + _session=None, # No session + _task_support=task_support, + ) + + async def work(task: ServerTaskContext) -> CallToolResult: + raise NotImplementedError + + with pytest.raises(RuntimeError, match="Session not available"): + await experimental.run_task(work) + + +@pytest.mark.anyio +async def test_run_task_without_task_metadata_raises() -> None: + """Test that run_task raises when request is not task-augmented.""" + task_support = TaskSupport.in_memory() + mock_session = Mock() + + experimental = Experimental( + task_metadata=None, # Not a task-augmented request + _client_capabilities=None, + _session=mock_session, + _task_support=task_support, + ) + + async def work(task: ServerTaskContext) -> CallToolResult: + raise NotImplementedError + + with pytest.raises(RuntimeError, match="Request is not task-augmented"): + await experimental.run_task(work) + + +@pytest.mark.anyio +async def test_run_task_with_model_immediate_response() -> None: + """Test that run_task includes model_immediate_response in CreateTaskResult._meta.""" + server = Server("test-run-task-immediate") + server.experimental.enable_tasks() + + work_completed = Event() + immediate_response_text = "Processing your request..." + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="task_with_immediate", + description="A task with immediate response", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + work_completed.set() + return CallToolResult(content=[TextContent(type="text", text="Done")]) + + return await ctx.experimental.run_task(work, model_immediate_response=immediate_response_text) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def run_client() -> None: + async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: + await client_session.initialize() + + result = await client_session.experimental.call_tool_as_task("task_with_immediate", {}) + + # Verify the immediate response is in _meta + assert result.meta is not None + assert "io.modelcontextprotocol/model-immediate-response" in result.meta + assert result.meta["io.modelcontextprotocol/model-immediate-response"] == immediate_response_text + + with anyio.fail_after(5): + await work_completed.wait() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + +@pytest.mark.anyio +async def test_run_task_doesnt_complete_if_already_terminal() -> None: + """Test that run_task doesn't auto-complete if work manually completed the task.""" + server = Server("test-already-complete") + server.experimental.enable_tasks() + + work_completed = Event() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="manual_complete_task", + description="A task that manually completes", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + # Manually complete the task before returning + manual_result = CallToolResult(content=[TextContent(type="text", text="Manually completed")]) + await task.complete(manual_result, notify=False) + work_completed.set() + # Return a different result - but it should be ignored since task is already terminal + return CallToolResult(content=[TextContent(type="text", text="This should be ignored")]) + + return await ctx.experimental.run_task(work) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def run_client() -> None: + async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: + await client_session.initialize() + + result = await client_session.experimental.call_tool_as_task("manual_complete_task", {}) + task_id = result.task.taskId + + with anyio.fail_after(5): + await work_completed.wait() + + # Poll until task status is completed + with anyio.fail_after(5): + while True: + status = await client_session.experimental.get_task(task_id) + if status.status == "completed": # pragma: no branch + break + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + +@pytest.mark.anyio +async def test_run_task_doesnt_fail_if_already_terminal() -> None: + """Test that run_task doesn't auto-fail if work manually failed/cancelled the task.""" + server = Server("test-already-failed") + server.experimental.enable_tasks() + + work_completed = Event() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="manual_cancel_task", + description="A task that manually cancels then raises", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + # Manually fail the task first + await task.fail("Manually failed", notify=False) + work_completed.set() + # Then raise - but the auto-fail should be skipped since task is already terminal + raise RuntimeError("This error should not change status") + + return await ctx.experimental.run_task(work) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def run_client() -> None: + async with ClientSession(server_to_client_receive, client_to_server_send) as client_session: + await client_session.initialize() + + result = await client_session.experimental.call_tool_as_task("manual_cancel_task", {}) + task_id = result.task.taskId + + with anyio.fail_after(5): + await work_completed.wait() + + # Poll until task status is failed + with anyio.fail_after(5): + while True: + status = await client_session.experimental.get_task(task_id) + if status.status == "failed": # pragma: no branch + break + + # Task should still be failed (from manual fail, not auto-fail from exception) + assert status.statusMessage == "Manually failed" # Not "This error should not change status" + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py new file mode 100644 index 0000000000..7209ed412a --- /dev/null +++ b/tests/experimental/tasks/server/test_server.py @@ -0,0 +1,965 @@ +"""Tests for server-side task support (handlers, capabilities, integration).""" + +from datetime import datetime, timezone +from typing import Any + +import anyio +import pytest + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.exceptions import McpError +from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.response_router import ResponseRouter +from mcp.shared.session import RequestResponder +from mcp.types import ( + INVALID_REQUEST, + TASK_FORBIDDEN, + TASK_OPTIONAL, + TASK_REQUIRED, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + CancelTaskRequest, + CancelTaskRequestParams, + CancelTaskResult, + ClientRequest, + ClientResult, + ErrorData, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCResponse, + ListTasksRequest, + ListTasksResult, + ListToolsRequest, + ListToolsResult, + SamplingMessage, + ServerCapabilities, + ServerNotification, + ServerRequest, + ServerResult, + Task, + TaskMetadata, + TextContent, + Tool, + ToolExecution, +) + + +@pytest.mark.anyio +async def test_list_tasks_handler() -> None: + """Test that experimental list_tasks handler works.""" + server = Server("test") + + now = datetime.now(timezone.utc) + test_tasks = [ + Task( + taskId="task-1", + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + pollInterval=1000, + ), + Task( + taskId="task-2", + status="completed", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + pollInterval=1000, + ), + ] + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=test_tasks) + + handler = server.request_handlers[ListTasksRequest] + request = ListTasksRequest(method="tasks/list") + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListTasksResult) + assert len(result.root.tasks) == 2 + assert result.root.tasks[0].taskId == "task-1" + assert result.root.tasks[1].taskId == "task-2" + + +@pytest.mark.anyio +async def test_get_task_handler() -> None: + """Test that experimental get_task handler works.""" + server = Server("test") + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + now = datetime.now(timezone.utc) + return GetTaskResult( + taskId=request.params.taskId, + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + pollInterval=1000, + ) + + handler = server.request_handlers[GetTaskRequest] + request = GetTaskRequest( + method="tasks/get", + params=GetTaskRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, GetTaskResult) + assert result.root.taskId == "test-task-123" + assert result.root.status == "working" + + +@pytest.mark.anyio +async def test_get_task_result_handler() -> None: + """Test that experimental get_task_result handler works.""" + server = Server("test") + + @server.experimental.get_task_result() + async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + return GetTaskPayloadResult() + + handler = server.request_handlers[GetTaskPayloadRequest] + request = GetTaskPayloadRequest( + method="tasks/result", + params=GetTaskPayloadRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, GetTaskPayloadResult) + + +@pytest.mark.anyio +async def test_cancel_task_handler() -> None: + """Test that experimental cancel_task handler works.""" + server = Server("test") + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + now = datetime.now(timezone.utc) + return CancelTaskResult( + taskId=request.params.taskId, + status="cancelled", + createdAt=now, + lastUpdatedAt=now, + ttl=60000, + ) + + handler = server.request_handlers[CancelTaskRequest] + request = CancelTaskRequest( + method="tasks/cancel", + params=CancelTaskRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, CancelTaskResult) + assert result.root.taskId == "test-task-123" + assert result.root.status == "cancelled" + + +@pytest.mark.anyio +async def test_server_capabilities_include_tasks() -> None: + """Test that server capabilities include tasks when handlers are registered.""" + server = Server("test") + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + raise NotImplementedError + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + raise NotImplementedError + + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + assert capabilities.tasks is not None + assert capabilities.tasks.list is not None + assert capabilities.tasks.cancel is not None + assert capabilities.tasks.requests is not None + assert capabilities.tasks.requests.tools is not None + + +@pytest.mark.anyio +async def test_server_capabilities_partial_tasks() -> None: + """Test capabilities with only some task handlers registered.""" + server = Server("test") + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + raise NotImplementedError + + # Only list_tasks registered, not cancel_task + + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + assert capabilities.tasks is not None + assert capabilities.tasks.list is not None + assert capabilities.tasks.cancel is None # Not registered + + +@pytest.mark.anyio +async def test_tool_with_task_execution_metadata() -> None: + """Test that tools can declare task execution mode.""" + server = Server("test") + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="quick_tool", + description="Fast tool", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport=TASK_FORBIDDEN), + ), + Tool( + name="long_tool", + description="Long running tool", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ), + Tool( + name="flexible_tool", + description="Can be either", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport=TASK_OPTIONAL), + ), + ] + + tools_handler = server.request_handlers[ListToolsRequest] + request = ListToolsRequest(method="tools/list") + result = await tools_handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListToolsResult) + tools = result.root.tools + + assert tools[0].execution is not None + assert tools[0].execution.taskSupport == TASK_FORBIDDEN + assert tools[1].execution is not None + assert tools[1].execution.taskSupport == TASK_REQUIRED + assert tools[2].execution is not None + assert tools[2].execution.taskSupport == TASK_OPTIONAL + + +@pytest.mark.anyio +async def test_task_metadata_in_call_tool_request() -> None: + """Test that task metadata is accessible via RequestContext when calling a tool.""" + server = Server("test") + captured_task_metadata: TaskMetadata | None = None + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="long_task", + description="A long running task", + inputSchema={"type": "object", "properties": {}}, + execution=ToolExecution(taskSupport="optional"), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + nonlocal captured_task_metadata + ctx = server.request_context + captured_task_metadata = ctx.experimental.task_metadata + return [TextContent(type="text", text="done")] + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async with anyio.create_task_group() as tg: + + async def handle_messages(): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + tg.start_soon(handle_messages) + await anyio.sleep_forever() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Call tool with task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CallToolResult, + ) + + tg.cancel_scope.cancel() + + assert captured_task_metadata is not None + assert captured_task_metadata.ttl == 60000 + + +@pytest.mark.anyio +async def test_task_metadata_is_task_property() -> None: + """Test that RequestContext.experimental.is_task works correctly.""" + server = Server("test") + is_task_values: list[bool] = [] + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="test_tool", + description="Test tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context + is_task_values.append(ctx.experimental.is_task) + return [TextContent(type="text", text="done")] + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async with anyio.create_task_group() as tg: + + async def handle_messages(): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + tg.start_soon(handle_messages) + await anyio.sleep_forever() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Call without task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}), + ) + ), + CallToolResult, + ) + + # Call with task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CallToolResult, + ) + + tg.cancel_scope.cancel() + + assert len(is_task_values) == 2 + assert is_task_values[0] is False # First call without task + assert is_task_values[1] is True # Second call with task + + +@pytest.mark.anyio +async def test_update_capabilities_no_handlers() -> None: + """Test that update_capabilities returns early when no task handlers are registered.""" + server = Server("test-no-handlers") + # Access experimental to initialize it, but don't register any task handlers + _ = server.experimental + + caps = server.get_capabilities(NotificationOptions(), {}) + + # Without any task handlers registered, tasks capability should be None + assert caps.tasks is None + + +@pytest.mark.anyio +async def test_default_task_handlers_via_enable_tasks() -> None: + """Test that enable_tasks() auto-registers working default handlers. + + This exercises the default handlers in lowlevel/experimental.py: + - _default_get_task (task not found) + - _default_get_task_result + - _default_list_tasks + - _default_cancel_task + """ + server = Server("test-default-handlers") + # Enable tasks with default handlers (no custom handlers registered) + task_support = server.experimental.enable_tasks() + store = task_support.store + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: ... # pragma: no branch + + async def run_server() -> None: + async with task_support.run(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + task_support.configure_session(server_session) + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task directly in the store for testing + task = await store.create_task(TaskMetadata(ttl=60000)) + + # Test list_tasks (default handler) + list_result = await client_session.send_request( + ClientRequest(ListTasksRequest()), + ListTasksResult, + ) + assert len(list_result.tasks) == 1 + assert list_result.tasks[0].taskId == task.taskId + + # Test get_task (default handler - found) + get_result = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task.taskId))), + GetTaskResult, + ) + assert get_result.taskId == task.taskId + assert get_result.status == "working" + + # Test get_task (default handler - not found path) + with pytest.raises(McpError, match="not found"): + await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId="nonexistent-task"))), + GetTaskResult, + ) + + # Create a completed task to test get_task_result + completed_task = await store.create_task(TaskMetadata(ttl=60000)) + await store.store_result( + completed_task.taskId, CallToolResult(content=[TextContent(type="text", text="Test result")]) + ) + await store.update_task(completed_task.taskId, status="completed") + + # Test get_task_result (default handler) + payload_result = await client_session.send_request( + ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=completed_task.taskId))), + GetTaskPayloadResult, + ) + # The result should have the related-task metadata + assert payload_result.meta is not None + assert "io.modelcontextprotocol/related-task" in payload_result.meta + + # Test cancel_task (default handler) + cancel_result = await client_session.send_request( + ClientRequest(CancelTaskRequest(params=CancelTaskRequestParams(taskId=task.taskId))), + CancelTaskResult, + ) + assert cancel_result.taskId == task.taskId + assert cancel_result.status == "cancelled" + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_build_elicit_form_request() -> None: + """Test that _build_elicit_form_request builds a proper elicitation request.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + # Test without task_id + request = server_session._build_elicit_form_request( + message="Test message", + requestedSchema={"type": "object", "properties": {"answer": {"type": "string"}}}, + ) + assert request.method == "elicitation/create" + assert request.params is not None + assert request.params["message"] == "Test message" + + # Test with related_task_id (adds related-task metadata) + request_with_task = server_session._build_elicit_form_request( + message="Task message", + requestedSchema={"type": "object"}, + related_task_id="test-task-123", + ) + assert request_with_task.method == "elicitation/create" + assert request_with_task.params is not None + assert "_meta" in request_with_task.params + assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] + assert ( + request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-123" + ) + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_build_elicit_url_request() -> None: + """Test that _build_elicit_url_request builds a proper URL mode elicitation request.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + # Test without related_task_id + request = server_session._build_elicit_url_request( + message="Please authorize with GitHub", + url="https://github.com/login/oauth/authorize", + elicitation_id="oauth-123", + ) + assert request.method == "elicitation/create" + assert request.params is not None + assert request.params["message"] == "Please authorize with GitHub" + assert request.params["url"] == "https://github.com/login/oauth/authorize" + assert request.params["elicitationId"] == "oauth-123" + assert request.params["mode"] == "url" + + # Test with related_task_id (adds related-task metadata) + request_with_task = server_session._build_elicit_url_request( + message="OAuth required", + url="https://example.com/oauth", + elicitation_id="oauth-456", + related_task_id="test-task-789", + ) + assert request_with_task.method == "elicitation/create" + assert request_with_task.params is not None + assert "_meta" in request_with_task.params + assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] + assert ( + request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] == "test-task-789" + ) + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_build_create_message_request() -> None: + """Test that _build_create_message_request builds a proper sampling request.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + messages = [ + SamplingMessage(role="user", content=TextContent(type="text", text="Hello")), + ] + + # Test without task_id + request = server_session._build_create_message_request( + messages=messages, + max_tokens=100, + system_prompt="You are helpful", + ) + assert request.method == "sampling/createMessage" + assert request.params is not None + assert request.params["maxTokens"] == 100 + + # Test with related_task_id (adds related-task metadata) + request_with_task = server_session._build_create_message_request( + messages=messages, + max_tokens=50, + related_task_id="sampling-task-456", + ) + assert request_with_task.method == "sampling/createMessage" + assert request_with_task.params is not None + assert "_meta" in request_with_task.params + assert "io.modelcontextprotocol/related-task" in request_with_task.params["_meta"] + assert ( + request_with_task.params["_meta"]["io.modelcontextprotocol/related-task"]["taskId"] + == "sampling-task-456" + ) + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_send_message() -> None: + """Test that send_message sends a raw session message.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + # Create a test message + notification = JSONRPCNotification(jsonrpc="2.0", method="test/notification") + message = SessionMessage( + message=JSONRPCMessage(notification), + metadata=ServerMessageMetadata(related_request_id="test-req-1"), + ) + + # Send the message + await server_session.send_message(message) + + # Verify it was sent to the stream + received = await server_to_client_receive.receive() + assert isinstance(received.message.root, JSONRPCNotification) + assert received.message.root.method == "test/notification" + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_response_routing_success() -> None: + """Test that response routing works for success responses.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Track routed responses with event for synchronization + routed_responses: list[dict[str, Any]] = [] + response_received = anyio.Event() + + class TestRouter(ResponseRouter): + def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: + routed_responses.append({"id": request_id, "response": response}) + response_received.set() + return True # Handled + + def route_error(self, request_id: str | int, error: ErrorData) -> bool: + raise NotImplementedError + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + router = TestRouter() + server_session.add_response_router(router) + + # Simulate receiving a response from client + response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) + message = SessionMessage(message=JSONRPCMessage(response)) + + # Send from "client" side + await client_to_server_send.send(message) + + # Wait for response to be routed + with anyio.fail_after(5): + await response_received.wait() + + # Verify response was routed + assert len(routed_responses) == 1 + assert routed_responses[0]["id"] == "test-req-1" + assert routed_responses[0]["response"]["status"] == "ok" + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_response_routing_error() -> None: + """Test that error routing works for error responses.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Track routed errors with event for synchronization + routed_errors: list[dict[str, Any]] = [] + error_received = anyio.Event() + + class TestRouter(ResponseRouter): + def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: + raise NotImplementedError + + def route_error(self, request_id: str | int, error: ErrorData) -> bool: + routed_errors.append({"id": request_id, "error": error}) + error_received.set() + return True # Handled + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + router = TestRouter() + server_session.add_response_router(router) + + # Simulate receiving an error response from client + error_data = ErrorData(code=INVALID_REQUEST, message="Test error") + error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) + message = SessionMessage(message=JSONRPCMessage(error_response)) + + # Send from "client" side + await client_to_server_send.send(message) + + # Wait for error to be routed + with anyio.fail_after(5): + await error_received.wait() + + # Verify error was routed + assert len(routed_errors) == 1 + assert routed_errors[0]["id"] == "test-req-2" + assert routed_errors[0]["error"].message == "Test error" + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_response_routing_skips_non_matching_routers() -> None: + """Test that routing continues to next router when first doesn't match.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Track which routers were called + router_calls: list[str] = [] + response_received = anyio.Event() + + class NonMatchingRouter(ResponseRouter): + def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: + router_calls.append("non_matching_response") + return False # Doesn't handle it + + def route_error(self, request_id: str | int, error: ErrorData) -> bool: + raise NotImplementedError + + class MatchingRouter(ResponseRouter): + def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: + router_calls.append("matching_response") + response_received.set() + return True # Handles it + + def route_error(self, request_id: str | int, error: ErrorData) -> bool: + raise NotImplementedError + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + # Add non-matching router first, then matching router + server_session.add_response_router(NonMatchingRouter()) + server_session.add_response_router(MatchingRouter()) + + # Send a response - should skip first router and be handled by second + response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) + message = SessionMessage(message=JSONRPCMessage(response)) + await client_to_server_send.send(message) + + with anyio.fail_after(5): + await response_received.wait() + + # Verify both routers were called (first returned False, second returned True) + assert router_calls == ["non_matching_response", "matching_response"] + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() + + +@pytest.mark.anyio +async def test_error_routing_skips_non_matching_routers() -> None: + """Test that error routing continues to next router when first doesn't match.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + # Track which routers were called + router_calls: list[str] = [] + error_received = anyio.Event() + + class NonMatchingRouter(ResponseRouter): + def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: + raise NotImplementedError + + def route_error(self, request_id: str | int, error: ErrorData) -> bool: + router_calls.append("non_matching_error") + return False # Doesn't handle it + + class MatchingRouter(ResponseRouter): + def route_response(self, request_id: str | int, response: dict[str, Any]) -> bool: + raise NotImplementedError + + def route_error(self, request_id: str | int, error: ErrorData) -> bool: + router_calls.append("matching_error") + error_received.set() + return True # Handles it + + try: + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + # Add non-matching router first, then matching router + server_session.add_response_router(NonMatchingRouter()) + server_session.add_response_router(MatchingRouter()) + + # Send an error - should skip first router and be handled by second + error_data = ErrorData(code=INVALID_REQUEST, message="Test error") + error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) + message = SessionMessage(message=JSONRPCMessage(error_response)) + await client_to_server_send.send(message) + + with anyio.fail_after(5): + await error_received.wait() + + # Verify both routers were called (first returned False, second returned True) + assert router_calls == ["non_matching_error", "matching_error"] + finally: # pragma: no cover + await server_to_client_send.aclose() + await server_to_client_receive.aclose() + await client_to_server_send.aclose() + await client_to_server_receive.aclose() diff --git a/tests/experimental/tasks/server/test_server_task_context.py b/tests/experimental/tasks/server/test_server_task_context.py new file mode 100644 index 0000000000..3d6b16f482 --- /dev/null +++ b/tests/experimental/tasks/server/test_server_task_context.py @@ -0,0 +1,709 @@ +"""Tests for ServerTaskContext.""" + +import asyncio +from unittest.mock import AsyncMock, Mock + +import anyio +import pytest + +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.experimental.task_result_handler import TaskResultHandler +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue +from mcp.types import ( + CallToolResult, + ClientCapabilities, + ClientTasksCapability, + ClientTasksRequestsCapability, + Implementation, + InitializeRequestParams, + JSONRPCRequest, + SamplingMessage, + TaskMetadata, + TasksCreateElicitationCapability, + TasksCreateMessageCapability, + TasksElicitationCapability, + TasksSamplingCapability, + TextContent, +) + + +@pytest.mark.anyio +async def test_server_task_context_properties() -> None: + """Test ServerTaskContext property accessors.""" + store = InMemoryTaskStore() + mock_session = Mock() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-123") + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + ) + + assert ctx.task_id == "test-123" + assert ctx.task.taskId == "test-123" + assert ctx.is_cancelled is False + + store.cleanup() + + +@pytest.mark.anyio +async def test_server_task_context_request_cancellation() -> None: + """Test ServerTaskContext.request_cancellation().""" + store = InMemoryTaskStore() + mock_session = Mock() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + ) + + assert ctx.is_cancelled is False + ctx.request_cancellation() + assert ctx.is_cancelled is True + + store.cleanup() + + +@pytest.mark.anyio +async def test_server_task_context_update_status_with_notify() -> None: + """Test update_status sends notification when notify=True.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.send_notification = AsyncMock() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + ) + + await ctx.update_status("Working...", notify=True) + + mock_session.send_notification.assert_called_once() + store.cleanup() + + +@pytest.mark.anyio +async def test_server_task_context_update_status_without_notify() -> None: + """Test update_status skips notification when notify=False.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.send_notification = AsyncMock() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + ) + + await ctx.update_status("Working...", notify=False) + + mock_session.send_notification.assert_not_called() + store.cleanup() + + +@pytest.mark.anyio +async def test_server_task_context_complete_with_notify() -> None: + """Test complete sends notification when notify=True.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.send_notification = AsyncMock() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + ) + + result = CallToolResult(content=[TextContent(type="text", text="Done")]) + await ctx.complete(result, notify=True) + + mock_session.send_notification.assert_called_once() + store.cleanup() + + +@pytest.mark.anyio +async def test_server_task_context_fail_with_notify() -> None: + """Test fail sends notification when notify=True.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.send_notification = AsyncMock() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + ) + + await ctx.fail("Something went wrong", notify=True) + + mock_session.send_notification.assert_called_once() + store.cleanup() + + +@pytest.mark.anyio +async def test_elicit_raises_when_client_lacks_capability() -> None: + """Test that elicit() raises McpError when client doesn't support elicitation.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=False) + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=handler, + ) + + with pytest.raises(McpError) as exc_info: + await ctx.elicit(message="Test?", requestedSchema={"type": "object"}) + + assert "elicitation capability" in exc_info.value.error.message + mock_session.check_client_capability.assert_called_once() + store.cleanup() + + +@pytest.mark.anyio +async def test_create_message_raises_when_client_lacks_capability() -> None: + """Test that create_message() raises McpError when client doesn't support sampling.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=False) + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=handler, + ) + + with pytest.raises(McpError) as exc_info: + await ctx.create_message(messages=[], max_tokens=100) + + assert "sampling capability" in exc_info.value.error.message + mock_session.check_client_capability.assert_called_once() + store.cleanup() + + +@pytest.mark.anyio +async def test_elicit_raises_without_handler() -> None: + """Test that elicit() raises when handler is not provided.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=None, + ) + + with pytest.raises(RuntimeError, match="handler is required"): + await ctx.elicit(message="Test?", requestedSchema={"type": "object"}) + + store.cleanup() + + +@pytest.mark.anyio +async def test_elicit_url_raises_without_handler() -> None: + """Test that elicit_url() raises when handler is not provided.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=None, + ) + + with pytest.raises(RuntimeError, match="handler is required for elicit_url"): + await ctx.elicit_url( + message="Please authorize", + url="https://example.com/oauth", + elicitation_id="oauth-123", + ) + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_message_raises_without_handler() -> None: + """Test that create_message() raises when handler is not provided.""" + store = InMemoryTaskStore() + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=None, + ) + + with pytest.raises(RuntimeError, match="handler is required"): + await ctx.create_message(messages=[], max_tokens=100) + + store.cleanup() + + +@pytest.mark.anyio +async def test_elicit_queues_request_and_waits_for_response() -> None: + """Test that elicit() queues request and waits for response.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + task = await store.create_task(TaskMetadata(ttl=60000)) + + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + mock_session._build_elicit_form_request = Mock( + return_value=JSONRPCRequest( + jsonrpc="2.0", + id="test-req-1", + method="elicitation/create", + params={"message": "Test?", "_meta": {}}, + ) + ) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=handler, + ) + + elicit_result = None + + async def run_elicit() -> None: + nonlocal elicit_result + elicit_result = await ctx.elicit( + message="Test?", + requestedSchema={"type": "object"}, + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_elicit) + + # Wait for request to be queued + await queue.wait_for_message(task.taskId) + + # Verify task is in input_required status + updated_task = await store.get_task(task.taskId) + assert updated_task is not None + assert updated_task.status == "input_required" + + # Dequeue and simulate response + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + + # Resolve with mock elicitation response + msg.resolver.set_result({"action": "accept", "content": {"name": "Alice"}}) + + # Verify result + assert elicit_result is not None + assert elicit_result.action == "accept" + assert elicit_result.content == {"name": "Alice"} + + # Verify task is back to working + final_task = await store.get_task(task.taskId) + assert final_task is not None + assert final_task.status == "working" + + store.cleanup() + + +@pytest.mark.anyio +async def test_elicit_url_queues_request_and_waits_for_response() -> None: + """Test that elicit_url() queues request and waits for response.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + task = await store.create_task(TaskMetadata(ttl=60000)) + + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + mock_session._build_elicit_url_request = Mock( + return_value=JSONRPCRequest( + jsonrpc="2.0", + id="test-url-req-1", + method="elicitation/create", + params={"message": "Authorize", "url": "https://example.com/", "elicitationId": "123", "mode": "url"}, + ) + ) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=handler, + ) + + elicit_result = None + + async def run_elicit_url() -> None: + nonlocal elicit_result + elicit_result = await ctx.elicit_url( + message="Authorize", + url="https://example.com/oauth", + elicitation_id="oauth-123", + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_elicit_url) + + # Wait for request to be queued + await queue.wait_for_message(task.taskId) + + # Verify task is in input_required status + updated_task = await store.get_task(task.taskId) + assert updated_task is not None + assert updated_task.status == "input_required" + + # Dequeue and simulate response + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + + # Resolve with mock elicitation response (URL mode just returns action) + msg.resolver.set_result({"action": "accept"}) + + # Verify result + assert elicit_result is not None + assert elicit_result.action == "accept" + + # Verify task is back to working + final_task = await store.get_task(task.taskId) + assert final_task is not None + assert final_task.status == "working" + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_message_queues_request_and_waits_for_response() -> None: + """Test that create_message() queues request and waits for response.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + task = await store.create_task(TaskMetadata(ttl=60000)) + + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + mock_session._build_create_message_request = Mock( + return_value=JSONRPCRequest( + jsonrpc="2.0", + id="test-req-2", + method="sampling/createMessage", + params={"messages": [], "maxTokens": 100, "_meta": {}}, + ) + ) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=handler, + ) + + sampling_result = None + + async def run_sampling() -> None: + nonlocal sampling_result + sampling_result = await ctx.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + max_tokens=100, + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_sampling) + + # Wait for request to be queued + await queue.wait_for_message(task.taskId) + + # Verify task is in input_required status + updated_task = await store.get_task(task.taskId) + assert updated_task is not None + assert updated_task.status == "input_required" + + # Dequeue and simulate response + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + + # Resolve with mock sampling response + msg.resolver.set_result( + { + "role": "assistant", + "content": {"type": "text", "text": "Hello back!"}, + "model": "test-model", + "stopReason": "endTurn", + } + ) + + # Verify result + assert sampling_result is not None + assert sampling_result.role == "assistant" + assert sampling_result.model == "test-model" + + # Verify task is back to working + final_task = await store.get_task(task.taskId) + assert final_task is not None + assert final_task.status == "working" + + store.cleanup() + + +@pytest.mark.anyio +async def test_elicit_restores_status_on_cancellation() -> None: + """Test that elicit() restores task status to working when cancelled.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + task = await store.create_task(TaskMetadata(ttl=60000)) + + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + mock_session._build_elicit_form_request = Mock( + return_value=JSONRPCRequest( + jsonrpc="2.0", + id="test-req-cancel", + method="elicitation/create", + params={"message": "Test?", "_meta": {}}, + ) + ) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=handler, + ) + + cancelled_error_raised = False + + async with anyio.create_task_group() as tg: + + async def do_elicit() -> None: + nonlocal cancelled_error_raised + try: + await ctx.elicit( + message="Test?", + requestedSchema={"type": "object"}, + ) + except anyio.get_cancelled_exc_class(): + cancelled_error_raised = True + # Don't re-raise - let the test continue + + tg.start_soon(do_elicit) + + # Wait for request to be queued + await queue.wait_for_message(task.taskId) + + # Verify task is in input_required status + updated_task = await store.get_task(task.taskId) + assert updated_task is not None + assert updated_task.status == "input_required" + + # Get the queued message and set cancellation exception on its resolver + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + + # Trigger cancellation by setting exception (use asyncio.CancelledError directly) + msg.resolver.set_exception(asyncio.CancelledError()) + + # Verify task is back to working after cancellation + final_task = await store.get_task(task.taskId) + assert final_task is not None + assert final_task.status == "working" + assert cancelled_error_raised + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_message_restores_status_on_cancellation() -> None: + """Test that create_message() restores task status to working when cancelled.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + handler = TaskResultHandler(store, queue) + task = await store.create_task(TaskMetadata(ttl=60000)) + + mock_session = Mock() + mock_session.check_client_capability = Mock(return_value=True) + mock_session._build_create_message_request = Mock( + return_value=JSONRPCRequest( + jsonrpc="2.0", + id="test-req-cancel-2", + method="sampling/createMessage", + params={"messages": [], "maxTokens": 100, "_meta": {}}, + ) + ) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=handler, + ) + + cancelled_error_raised = False + + async with anyio.create_task_group() as tg: + + async def do_sampling() -> None: + nonlocal cancelled_error_raised + try: + await ctx.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + max_tokens=100, + ) + except anyio.get_cancelled_exc_class(): + cancelled_error_raised = True + # Don't re-raise + + tg.start_soon(do_sampling) + + # Wait for request to be queued + await queue.wait_for_message(task.taskId) + + # Verify task is in input_required status + updated_task = await store.get_task(task.taskId) + assert updated_task is not None + assert updated_task.status == "input_required" + + # Get the queued message and set cancellation exception on its resolver + msg = await queue.dequeue(task.taskId) + assert msg is not None + assert msg.resolver is not None + + # Trigger cancellation by setting exception (use asyncio.CancelledError directly) + msg.resolver.set_exception(asyncio.CancelledError()) + + # Verify task is back to working after cancellation + final_task = await store.get_task(task.taskId) + assert final_task is not None + assert final_task.status == "working" + assert cancelled_error_raised + + store.cleanup() + + +@pytest.mark.anyio +async def test_elicit_as_task_raises_without_handler() -> None: + """Test that elicit_as_task() raises when handler is not provided.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + # Create mock session with proper client capabilities + mock_session = Mock() + mock_session.client_params = InitializeRequestParams( + protocolVersion="2025-01-01", + capabilities=ClientCapabilities( + tasks=ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) + ) + ) + ), + clientInfo=Implementation(name="test", version="1.0"), + ) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=None, + ) + + with pytest.raises(RuntimeError, match="handler is required for elicit_as_task"): + await ctx.elicit_as_task(message="Test?", requestedSchema={"type": "object"}) + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_message_as_task_raises_without_handler() -> None: + """Test that create_message_as_task() raises when handler is not provided.""" + store = InMemoryTaskStore() + queue = InMemoryTaskMessageQueue() + task = await store.create_task(TaskMetadata(ttl=60000)) + + # Create mock session with proper client capabilities + mock_session = Mock() + mock_session.client_params = InitializeRequestParams( + protocolVersion="2025-01-01", + capabilities=ClientCapabilities( + tasks=ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()) + ) + ) + ), + clientInfo=Implementation(name="test", version="1.0"), + ) + + ctx = ServerTaskContext( + task=task, + store=store, + session=mock_session, + queue=queue, + handler=None, + ) + + with pytest.raises(RuntimeError, match="handler is required for create_message_as_task"): + await ctx.create_message_as_task( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + max_tokens=100, + ) + + store.cleanup() diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py new file mode 100644 index 0000000000..2eac31dfe6 --- /dev/null +++ b/tests/experimental/tasks/server/test_store.py @@ -0,0 +1,406 @@ +"""Tests for InMemoryTaskStore.""" + +from collections.abc import AsyncIterator +from datetime import datetime, timedelta, timezone + +import pytest + +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.helpers import cancel_task +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.types import INVALID_PARAMS, CallToolResult, TaskMetadata, TextContent + + +@pytest.fixture +async def store() -> AsyncIterator[InMemoryTaskStore]: + """Provide a clean InMemoryTaskStore for each test with automatic cleanup.""" + store = InMemoryTaskStore() + yield store + store.cleanup() + + +@pytest.mark.anyio +async def test_create_and_get(store: InMemoryTaskStore) -> None: + """Test InMemoryTaskStore create and get operations.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + assert task.taskId is not None + assert task.status == "working" + assert task.ttl == 60000 + + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.taskId == task.taskId + assert retrieved.status == "working" + + +@pytest.mark.anyio +async def test_create_with_custom_id(store: InMemoryTaskStore) -> None: + """Test InMemoryTaskStore create with custom task ID.""" + task = await store.create_task( + metadata=TaskMetadata(ttl=60000), + task_id="my-custom-id", + ) + + assert task.taskId == "my-custom-id" + assert task.status == "working" + + retrieved = await store.get_task("my-custom-id") + assert retrieved is not None + assert retrieved.taskId == "my-custom-id" + + +@pytest.mark.anyio +async def test_create_duplicate_id_raises(store: InMemoryTaskStore) -> None: + """Test that creating a task with duplicate ID raises.""" + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + + with pytest.raises(ValueError, match="already exists"): + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + + +@pytest.mark.anyio +async def test_get_nonexistent_returns_none(store: InMemoryTaskStore) -> None: + """Test that getting a nonexistent task returns None.""" + retrieved = await store.get_task("nonexistent") + assert retrieved is None + + +@pytest.mark.anyio +async def test_update_status(store: InMemoryTaskStore) -> None: + """Test InMemoryTaskStore status updates.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + updated = await store.update_task(task.taskId, status="completed", status_message="All done!") + + assert updated.status == "completed" + assert updated.statusMessage == "All done!" + + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.status == "completed" + assert retrieved.statusMessage == "All done!" + + +@pytest.mark.anyio +async def test_update_nonexistent_raises(store: InMemoryTaskStore) -> None: + """Test that updating a nonexistent task raises.""" + with pytest.raises(ValueError, match="not found"): + await store.update_task("nonexistent", status="completed") + + +@pytest.mark.anyio +async def test_store_and_get_result(store: InMemoryTaskStore) -> None: + """Test InMemoryTaskStore result storage and retrieval.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Store result + result = CallToolResult(content=[TextContent(type="text", text="Result data")]) + await store.store_result(task.taskId, result) + + # Retrieve result + retrieved_result = await store.get_result(task.taskId) + assert retrieved_result == result + + +@pytest.mark.anyio +async def test_get_result_nonexistent_returns_none(store: InMemoryTaskStore) -> None: + """Test that getting result for nonexistent task returns None.""" + result = await store.get_result("nonexistent") + assert result is None + + +@pytest.mark.anyio +async def test_get_result_no_result_returns_none(store: InMemoryTaskStore) -> None: + """Test that getting result when none stored returns None.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + result = await store.get_result(task.taskId) + assert result is None + + +@pytest.mark.anyio +async def test_list_tasks(store: InMemoryTaskStore) -> None: + """Test InMemoryTaskStore list operation.""" + # Create multiple tasks + for _ in range(3): + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + tasks, next_cursor = await store.list_tasks() + assert len(tasks) == 3 + assert next_cursor is None # Less than page size + + +@pytest.mark.anyio +async def test_list_tasks_pagination() -> None: + """Test InMemoryTaskStore pagination.""" + # Needs custom page_size, can't use fixture + store = InMemoryTaskStore(page_size=2) + + # Create 5 tasks + for _ in range(5): + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # First page + tasks, next_cursor = await store.list_tasks() + assert len(tasks) == 2 + assert next_cursor is not None + + # Second page + tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + assert len(tasks) == 2 + assert next_cursor is not None + + # Third page (last) + tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + assert len(tasks) == 1 + assert next_cursor is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_list_tasks_invalid_cursor(store: InMemoryTaskStore) -> None: + """Test that invalid cursor raises.""" + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + with pytest.raises(ValueError, match="Invalid cursor"): + await store.list_tasks(cursor="invalid-cursor") + + +@pytest.mark.anyio +async def test_delete_task(store: InMemoryTaskStore) -> None: + """Test InMemoryTaskStore delete operation.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + deleted = await store.delete_task(task.taskId) + assert deleted is True + + retrieved = await store.get_task(task.taskId) + assert retrieved is None + + # Delete non-existent + deleted = await store.delete_task(task.taskId) + assert deleted is False + + +@pytest.mark.anyio +async def test_get_all_tasks_helper(store: InMemoryTaskStore) -> None: + """Test the get_all_tasks debugging helper.""" + await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + all_tasks = store.get_all_tasks() + assert len(all_tasks) == 2 + + +@pytest.mark.anyio +async def test_store_result_nonexistent_raises(store: InMemoryTaskStore) -> None: + """Test that storing result for nonexistent task raises ValueError.""" + result = CallToolResult(content=[TextContent(type="text", text="Result")]) + + with pytest.raises(ValueError, match="not found"): + await store.store_result("nonexistent-id", result) + + +@pytest.mark.anyio +async def test_create_task_with_null_ttl(store: InMemoryTaskStore) -> None: + """Test creating task with null TTL (never expires).""" + task = await store.create_task(metadata=TaskMetadata(ttl=None)) + + assert task.ttl is None + + # Task should persist (not expire) + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + + +@pytest.mark.anyio +async def test_task_expiration_cleanup(store: InMemoryTaskStore) -> None: + """Test that expired tasks are cleaned up lazily.""" + # Create a task with very short TTL + task = await store.create_task(metadata=TaskMetadata(ttl=1)) # 1ms TTL + + # Manually force the expiry to be in the past + stored = store._tasks.get(task.taskId) + assert stored is not None + stored.expires_at = datetime.now(timezone.utc) - timedelta(seconds=10) + + # Task should still exist in internal dict but be expired + assert task.taskId in store._tasks + + # Any access operation should clean up expired tasks + # list_tasks triggers cleanup + tasks, _ = await store.list_tasks() + + # Expired task should be cleaned up + assert task.taskId not in store._tasks + assert len(tasks) == 0 + + +@pytest.mark.anyio +async def test_task_with_null_ttl_never_expires(store: InMemoryTaskStore) -> None: + """Test that tasks with null TTL never expire during cleanup.""" + # Create task with null TTL + task = await store.create_task(metadata=TaskMetadata(ttl=None)) + + # Verify internal storage has no expiry + stored = store._tasks.get(task.taskId) + assert stored is not None + assert stored.expires_at is None + + # Access operations should NOT remove this task + await store.list_tasks() + await store.get_task(task.taskId) + + # Task should still exist + assert task.taskId in store._tasks + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + + +@pytest.mark.anyio +async def test_terminal_task_ttl_reset(store: InMemoryTaskStore) -> None: + """Test that TTL is reset when task enters terminal state.""" + # Create task with short TTL + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) # 60s + + # Get the initial expiry + stored = store._tasks.get(task.taskId) + assert stored is not None + initial_expiry = stored.expires_at + assert initial_expiry is not None + + # Update to terminal state (completed) + await store.update_task(task.taskId, status="completed") + + # Expiry should be reset to a new time (from now + TTL) + new_expiry = stored.expires_at + assert new_expiry is not None + assert new_expiry >= initial_expiry + + +@pytest.mark.anyio +async def test_terminal_status_transition_rejected(store: InMemoryTaskStore) -> None: + """Test that transitions from terminal states are rejected. + + Per spec: Terminal states (completed, failed, cancelled) MUST NOT + transition to any other status. + """ + # Test each terminal status + for terminal_status in ("completed", "failed", "cancelled"): + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Move to terminal state + await store.update_task(task.taskId, status=terminal_status) + + # Attempting to transition to any other status should raise + with pytest.raises(ValueError, match="Cannot transition from terminal status"): + await store.update_task(task.taskId, status="working") + + # Also test transitioning to another terminal state + other_terminal = "failed" if terminal_status != "failed" else "completed" + with pytest.raises(ValueError, match="Cannot transition from terminal status"): + await store.update_task(task.taskId, status=other_terminal) + + +@pytest.mark.anyio +async def test_terminal_status_allows_same_status(store: InMemoryTaskStore) -> None: + """Test that setting the same terminal status doesn't raise. + + This is not a transition, so it should be allowed (no-op). + """ + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="completed") + + # Setting the same status should not raise + updated = await store.update_task(task.taskId, status="completed") + assert updated.status == "completed" + + # Updating just the message should also work + updated = await store.update_task(task.taskId, status_message="Updated message") + assert updated.statusMessage == "Updated message" + + +@pytest.mark.anyio +async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> None: + """Test that wait_for_update raises for nonexistent task.""" + with pytest.raises(ValueError, match="not found"): + await store.wait_for_update("nonexistent-task-id") + + +@pytest.mark.anyio +async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None: + """Test cancel_task helper succeeds for a working task.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + assert task.status == "working" + + result = await cancel_task(store, task.taskId) + + assert result.taskId == task.taskId + assert result.status == "cancelled" + + # Verify store is updated + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.status == "cancelled" + + +@pytest.mark.anyio +async def test_cancel_task_rejects_nonexistent_task(store: InMemoryTaskStore) -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for nonexistent task.""" + with pytest.raises(McpError) as exc_info: + await cancel_task(store, "nonexistent-task-id") + + assert exc_info.value.error.code == INVALID_PARAMS + assert "not found" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_cancel_task_rejects_completed_task(store: InMemoryTaskStore) -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for completed task.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="completed") + + with pytest.raises(McpError) as exc_info: + await cancel_task(store, task.taskId) + + assert exc_info.value.error.code == INVALID_PARAMS + assert "terminal state 'completed'" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_cancel_task_rejects_failed_task(store: InMemoryTaskStore) -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for failed task.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="failed") + + with pytest.raises(McpError) as exc_info: + await cancel_task(store, task.taskId) + + assert exc_info.value.error.code == INVALID_PARAMS + assert "terminal state 'failed'" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_cancel_task_rejects_already_cancelled_task(store: InMemoryTaskStore) -> None: + """Test cancel_task raises McpError with INVALID_PARAMS for already cancelled task.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="cancelled") + + with pytest.raises(McpError) as exc_info: + await cancel_task(store, task.taskId) + + assert exc_info.value.error.code == INVALID_PARAMS + assert "terminal state 'cancelled'" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_cancel_task_succeeds_for_input_required_task(store: InMemoryTaskStore) -> None: + """Test cancel_task helper succeeds for a task in input_required status.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.update_task(task.taskId, status="input_required") + + result = await cancel_task(store, task.taskId) + + assert result.taskId == task.taskId + assert result.status == "cancelled" diff --git a/tests/experimental/tasks/server/test_task_result_handler.py b/tests/experimental/tasks/server/test_task_result_handler.py new file mode 100644 index 0000000000..db5b9edc70 --- /dev/null +++ b/tests/experimental/tasks/server/test_task_result_handler.py @@ -0,0 +1,354 @@ +"""Tests for TaskResultHandler.""" + +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import AsyncMock, Mock + +import anyio +import pytest + +from mcp.server.experimental.task_result_handler import TaskResultHandler +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_REQUEST, + CallToolResult, + ErrorData, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + JSONRPCRequest, + TaskMetadata, + TextContent, +) + + +@pytest.fixture +async def store() -> AsyncIterator[InMemoryTaskStore]: + """Provide a clean store for each test.""" + s = InMemoryTaskStore() + yield s + s.cleanup() + + +@pytest.fixture +def queue() -> InMemoryTaskMessageQueue: + """Provide a clean queue for each test.""" + return InMemoryTaskMessageQueue() + + +@pytest.fixture +def handler(store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue) -> TaskResultHandler: + """Provide a handler for each test.""" + return TaskResultHandler(store, queue) + + +@pytest.mark.anyio +async def test_handle_returns_result_for_completed_task( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that handle() returns the stored result for a completed task.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + result = CallToolResult(content=[TextContent(type="text", text="Done!")]) + await store.store_result(task.taskId, result) + await store.update_task(task.taskId, status="completed") + + mock_session = Mock() + mock_session.send_message = AsyncMock() + + request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId)) + response = await handler.handle(request, mock_session, "req-1") + + assert response is not None + assert response.meta is not None + assert "io.modelcontextprotocol/related-task" in response.meta + + +@pytest.mark.anyio +async def test_handle_raises_for_nonexistent_task( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that handle() raises McpError for nonexistent task.""" + mock_session = Mock() + request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId="nonexistent")) + + with pytest.raises(McpError) as exc_info: + await handler.handle(request, mock_session, "req-1") + + assert "not found" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_handle_returns_empty_result_when_no_result_stored( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that handle() returns minimal result when task completed without stored result.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + await store.update_task(task.taskId, status="completed") + + mock_session = Mock() + mock_session.send_message = AsyncMock() + + request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId)) + response = await handler.handle(request, mock_session, "req-1") + + assert response is not None + assert response.meta is not None + assert "io.modelcontextprotocol/related-task" in response.meta + + +@pytest.mark.anyio +async def test_handle_delivers_queued_messages( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that handle() delivers queued messages before returning.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + + queued_msg = QueuedMessage( + type="notification", + message=JSONRPCRequest( + jsonrpc="2.0", + id="notif-1", + method="test/notification", + params={}, + ), + ) + await queue.enqueue(task.taskId, queued_msg) + await store.update_task(task.taskId, status="completed") + + sent_messages: list[SessionMessage] = [] + + async def track_send(msg: SessionMessage) -> None: + sent_messages.append(msg) + + mock_session = Mock() + mock_session.send_message = track_send + + request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId)) + await handler.handle(request, mock_session, "req-1") + + assert len(sent_messages) == 1 + + +@pytest.mark.anyio +async def test_handle_waits_for_task_completion( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that handle() waits for task to complete before returning.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + + mock_session = Mock() + mock_session.send_message = AsyncMock() + + request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task.taskId)) + result_holder: list[GetTaskPayloadResult | None] = [None] + + async def run_handle() -> None: + result_holder[0] = await handler.handle(request, mock_session, "req-1") + + async with anyio.create_task_group() as tg: + tg.start_soon(run_handle) + + # Wait for handler to start waiting (event gets created when wait starts) + while task.taskId not in store._update_events: + await anyio.sleep(0) + + await store.store_result(task.taskId, CallToolResult(content=[TextContent(type="text", text="Done")])) + await store.update_task(task.taskId, status="completed") + + assert result_holder[0] is not None + + +@pytest.mark.anyio +async def test_route_response_resolves_pending_request( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that route_response() resolves a pending request.""" + resolver: Resolver[dict[str, Any]] = Resolver() + handler._pending_requests["req-123"] = resolver + + result = handler.route_response("req-123", {"status": "ok"}) + + assert result is True + assert resolver.done() + assert await resolver.wait() == {"status": "ok"} + + +@pytest.mark.anyio +async def test_route_response_returns_false_for_unknown_request( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that route_response() returns False for unknown request ID.""" + result = handler.route_response("unknown-req", {"status": "ok"}) + assert result is False + + +@pytest.mark.anyio +async def test_route_response_returns_false_for_already_done_resolver( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that route_response() returns False if resolver already completed.""" + resolver: Resolver[dict[str, Any]] = Resolver() + resolver.set_result({"already": "done"}) + handler._pending_requests["req-123"] = resolver + + result = handler.route_response("req-123", {"new": "data"}) + + assert result is False + + +@pytest.mark.anyio +async def test_route_error_resolves_pending_request_with_exception( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that route_error() sets exception on pending request.""" + resolver: Resolver[dict[str, Any]] = Resolver() + handler._pending_requests["req-123"] = resolver + + error = ErrorData(code=INVALID_REQUEST, message="Something went wrong") + result = handler.route_error("req-123", error) + + assert result is True + assert resolver.done() + + with pytest.raises(McpError) as exc_info: + await resolver.wait() + assert exc_info.value.error.message == "Something went wrong" + + +@pytest.mark.anyio +async def test_route_error_returns_false_for_unknown_request( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that route_error() returns False for unknown request ID.""" + error = ErrorData(code=INVALID_REQUEST, message="Error") + result = handler.route_error("unknown-req", error) + assert result is False + + +@pytest.mark.anyio +async def test_deliver_registers_resolver_for_request_messages( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that _deliver_queued_messages registers resolvers for request messages.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + + resolver: Resolver[dict[str, Any]] = Resolver() + queued_msg = QueuedMessage( + type="request", + message=JSONRPCRequest( + jsonrpc="2.0", + id="inner-req-1", + method="elicitation/create", + params={}, + ), + resolver=resolver, + original_request_id="inner-req-1", + ) + await queue.enqueue(task.taskId, queued_msg) + + mock_session = Mock() + mock_session.send_message = AsyncMock() + + await handler._deliver_queued_messages(task.taskId, mock_session, "outer-req-1") + + assert "inner-req-1" in handler._pending_requests + assert handler._pending_requests["inner-req-1"] is resolver + + +@pytest.mark.anyio +async def test_deliver_skips_resolver_registration_when_no_original_id( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that _deliver_queued_messages skips resolver registration when original_request_id is None.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + + resolver: Resolver[dict[str, Any]] = Resolver() + queued_msg = QueuedMessage( + type="request", + message=JSONRPCRequest( + jsonrpc="2.0", + id="inner-req-1", + method="elicitation/create", + params={}, + ), + resolver=resolver, + original_request_id=None, # No original request ID + ) + await queue.enqueue(task.taskId, queued_msg) + + mock_session = Mock() + mock_session.send_message = AsyncMock() + + await handler._deliver_queued_messages(task.taskId, mock_session, "outer-req-1") + + # Resolver should NOT be registered since original_request_id is None + assert len(handler._pending_requests) == 0 + # But the message should still be sent + mock_session.send_message.assert_called_once() + + +@pytest.mark.anyio +async def test_wait_for_task_update_handles_store_exception( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that _wait_for_task_update handles store exception gracefully.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + + # Make wait_for_update raise an exception + async def failing_wait(task_id: str) -> None: + raise RuntimeError("Store error") + + store.wait_for_update = failing_wait # type: ignore[method-assign] + + # Queue a message to unblock the race via the queue path + async def enqueue_later() -> None: + # Wait for queue to start waiting (event gets created when wait starts) + while task.taskId not in queue._events: + await anyio.sleep(0) + await queue.enqueue( + task.taskId, + QueuedMessage( + type="notification", + message=JSONRPCRequest( + jsonrpc="2.0", + id="notif-1", + method="test/notification", + params={}, + ), + ), + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(enqueue_later) + # This should complete via the queue path even though store raises + await handler._wait_for_task_update(task.taskId) + + +@pytest.mark.anyio +async def test_wait_for_task_update_handles_queue_exception( + store: InMemoryTaskStore, queue: InMemoryTaskMessageQueue, handler: TaskResultHandler +) -> None: + """Test that _wait_for_task_update handles queue exception gracefully.""" + task = await store.create_task(TaskMetadata(ttl=60000), task_id="test-task") + + # Make wait_for_message raise an exception + async def failing_wait(task_id: str) -> None: + raise RuntimeError("Queue error") + + queue.wait_for_message = failing_wait # type: ignore[method-assign] + + # Update the store to unblock the race via the store path + async def update_later() -> None: + # Wait for store to start waiting (event gets created when wait starts) + while task.taskId not in store._update_events: + await anyio.sleep(0) + await store.update_task(task.taskId, status="completed") + + async with anyio.create_task_group() as tg: + tg.start_soon(update_later) + # This should complete via the store path even though queue raises + await handler._wait_for_task_update(task.taskId) diff --git a/tests/experimental/tasks/test_capabilities.py b/tests/experimental/tasks/test_capabilities.py new file mode 100644 index 0000000000..e78f16fe3f --- /dev/null +++ b/tests/experimental/tasks/test_capabilities.py @@ -0,0 +1,283 @@ +"""Tests for tasks capability checking utilities.""" + +import pytest + +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.capabilities import ( + check_tasks_capability, + has_task_augmented_elicitation, + has_task_augmented_sampling, + require_task_augmented_elicitation, + require_task_augmented_sampling, +) +from mcp.types import ( + ClientCapabilities, + ClientTasksCapability, + ClientTasksRequestsCapability, + TasksCreateElicitationCapability, + TasksCreateMessageCapability, + TasksElicitationCapability, + TasksSamplingCapability, +) + + +class TestCheckTasksCapability: + """Tests for check_tasks_capability function.""" + + def test_required_requests_none_returns_true(self) -> None: + """When required.requests is None, should return True.""" + required = ClientTasksCapability() + client = ClientTasksCapability() + assert check_tasks_capability(required, client) is True + + def test_client_requests_none_returns_false(self) -> None: + """When client.requests is None but required.requests is set, should return False.""" + required = ClientTasksCapability(requests=ClientTasksRequestsCapability()) + client = ClientTasksCapability() + assert check_tasks_capability(required, client) is False + + def test_elicitation_required_but_client_missing(self) -> None: + """When elicitation is required but client doesn't have it.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) + ) + client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) + assert check_tasks_capability(required, client) is False + + def test_elicitation_create_required_but_client_missing(self) -> None: + """When elicitation.create is required but client doesn't have it.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) + ) + ) + client = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability() # No create + ) + ) + assert check_tasks_capability(required, client) is False + + def test_elicitation_create_present(self) -> None: + """When elicitation.create is required and client has it.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) + ) + ) + client = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) + ) + ) + assert check_tasks_capability(required, client) is True + + def test_sampling_required_but_client_missing(self) -> None: + """When sampling is required but client doesn't have it.""" + required = ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) + client = ClientTasksCapability(requests=ClientTasksRequestsCapability()) + assert check_tasks_capability(required, client) is False + + def test_sampling_create_message_required_but_client_missing(self) -> None: + """When sampling.createMessage is required but client doesn't have it.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()) + ) + ) + client = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability() # No createMessage + ) + ) + assert check_tasks_capability(required, client) is False + + def test_sampling_create_message_present(self) -> None: + """When sampling.createMessage is required and client has it.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()) + ) + ) + client = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()) + ) + ) + assert check_tasks_capability(required, client) is True + + def test_both_elicitation_and_sampling_present(self) -> None: + """When both elicitation.create and sampling.createMessage are required and client has both.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()), + ) + ) + client = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()), + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()), + ) + ) + assert check_tasks_capability(required, client) is True + + def test_elicitation_without_create_required(self) -> None: + """When elicitation is required but not create specifically.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability() # No create + ) + ) + client = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) + ) + ) + assert check_tasks_capability(required, client) is True + + def test_sampling_without_create_message_required(self) -> None: + """When sampling is required but not createMessage specifically.""" + required = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability() # No createMessage + ) + ) + client = ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()) + ) + ) + assert check_tasks_capability(required, client) is True + + +class TestHasTaskAugmentedElicitation: + """Tests for has_task_augmented_elicitation function.""" + + def test_tasks_none(self) -> None: + """Returns False when caps.tasks is None.""" + caps = ClientCapabilities() + assert has_task_augmented_elicitation(caps) is False + + def test_requests_none(self) -> None: + """Returns False when caps.tasks.requests is None.""" + caps = ClientCapabilities(tasks=ClientTasksCapability()) + assert has_task_augmented_elicitation(caps) is False + + def test_elicitation_none(self) -> None: + """Returns False when caps.tasks.requests.elicitation is None.""" + caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) + assert has_task_augmented_elicitation(caps) is False + + def test_create_none(self) -> None: + """Returns False when caps.tasks.requests.elicitation.create is None.""" + caps = ClientCapabilities( + tasks=ClientTasksCapability( + requests=ClientTasksRequestsCapability(elicitation=TasksElicitationCapability()) + ) + ) + assert has_task_augmented_elicitation(caps) is False + + def test_create_present(self) -> None: + """Returns True when full capability path is present.""" + caps = ClientCapabilities( + tasks=ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) + ) + ) + ) + assert has_task_augmented_elicitation(caps) is True + + +class TestHasTaskAugmentedSampling: + """Tests for has_task_augmented_sampling function.""" + + def test_tasks_none(self) -> None: + """Returns False when caps.tasks is None.""" + caps = ClientCapabilities() + assert has_task_augmented_sampling(caps) is False + + def test_requests_none(self) -> None: + """Returns False when caps.tasks.requests is None.""" + caps = ClientCapabilities(tasks=ClientTasksCapability()) + assert has_task_augmented_sampling(caps) is False + + def test_sampling_none(self) -> None: + """Returns False when caps.tasks.requests.sampling is None.""" + caps = ClientCapabilities(tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability())) + assert has_task_augmented_sampling(caps) is False + + def test_create_message_none(self) -> None: + """Returns False when caps.tasks.requests.sampling.createMessage is None.""" + caps = ClientCapabilities( + tasks=ClientTasksCapability(requests=ClientTasksRequestsCapability(sampling=TasksSamplingCapability())) + ) + assert has_task_augmented_sampling(caps) is False + + def test_create_message_present(self) -> None: + """Returns True when full capability path is present.""" + caps = ClientCapabilities( + tasks=ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()) + ) + ) + ) + assert has_task_augmented_sampling(caps) is True + + +class TestRequireTaskAugmentedElicitation: + """Tests for require_task_augmented_elicitation function.""" + + def test_raises_when_none(self) -> None: + """Raises McpError when client_caps is None.""" + with pytest.raises(McpError) as exc_info: + require_task_augmented_elicitation(None) + assert "task-augmented elicitation" in str(exc_info.value) + + def test_raises_when_missing(self) -> None: + """Raises McpError when capability is missing.""" + caps = ClientCapabilities() + with pytest.raises(McpError) as exc_info: + require_task_augmented_elicitation(caps) + assert "task-augmented elicitation" in str(exc_info.value) + + def test_passes_when_present(self) -> None: + """Does not raise when capability is present.""" + caps = ClientCapabilities( + tasks=ClientTasksCapability( + requests=ClientTasksRequestsCapability( + elicitation=TasksElicitationCapability(create=TasksCreateElicitationCapability()) + ) + ) + ) + require_task_augmented_elicitation(caps) + + +class TestRequireTaskAugmentedSampling: + """Tests for require_task_augmented_sampling function.""" + + def test_raises_when_none(self) -> None: + """Raises McpError when client_caps is None.""" + with pytest.raises(McpError) as exc_info: + require_task_augmented_sampling(None) + assert "task-augmented sampling" in str(exc_info.value) + + def test_raises_when_missing(self) -> None: + """Raises McpError when capability is missing.""" + caps = ClientCapabilities() + with pytest.raises(McpError) as exc_info: + require_task_augmented_sampling(caps) + assert "task-augmented sampling" in str(exc_info.value) + + def test_passes_when_present(self) -> None: + """Does not raise when capability is present.""" + caps = ClientCapabilities( + tasks=ClientTasksCapability( + requests=ClientTasksRequestsCapability( + sampling=TasksSamplingCapability(createMessage=TasksCreateMessageCapability()) + ) + ) + ) + require_task_augmented_sampling(caps) diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py new file mode 100644 index 0000000000..be2b616018 --- /dev/null +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -0,0 +1,737 @@ +""" +Tests for the four elicitation scenarios with tasks. + +This tests all combinations of tool call types and elicitation types: +1. Normal tool call + Normal elicitation (session.elicit) +2. Normal tool call + Task-augmented elicitation (session.experimental.elicit_as_task) +3. Task-augmented tool call + Normal elicitation (task.elicit) +4. Task-augmented tool call + Task-augmented elicitation (task.elicit_as_task) + +And the same for sampling (create_message). +""" + +from typing import Any + +import anyio +import pytest +from anyio import Event + +from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import NotificationOptions +from mcp.shared.context import RequestContext +from mcp.shared.experimental.tasks.helpers import is_terminal +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + CreateTaskResult, + ElicitRequestParams, + ElicitResult, + ErrorData, + GetTaskPayloadResult, + GetTaskResult, + SamplingMessage, + TaskMetadata, + TextContent, + Tool, + ToolExecution, +) + + +def create_client_task_handlers( + client_task_store: InMemoryTaskStore, + elicit_received: Event, +) -> ExperimentalTaskHandlers: + """Create task handlers for client to handle task-augmented elicitation from server.""" + + elicit_response = ElicitResult(action="http://github.com/accept", content={"confirm": True}) + task_complete_events: dict[str, Event] = {} + + async def handle_augmented_elicitation( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, + task_metadata: TaskMetadata, + ) -> CreateTaskResult: + """Handle task-augmented elicitation by creating a client-side task.""" + elicit_received.set() + task = await client_task_store.create_task(task_metadata) + task_complete_events[task.taskId] = Event() + + async def complete_task() -> None: + # Store result before updating status to avoid race condition + await client_task_store.store_result(task.taskId, elicit_response) + await client_task_store.update_task(task.taskId, status="completed") + task_complete_events[task.taskId].set() + + context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] + return CreateTaskResult(task=task) + + async def handle_get_task( + context: RequestContext[ClientSession, Any], + params: Any, + ) -> GetTaskResult: + """Handle tasks/get from server.""" + task = await client_task_store.get_task(params.taskId) + assert task is not None, f"Task not found: {params.taskId}" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=100, + ) + + async def handle_get_task_result( + context: RequestContext[ClientSession, Any], + params: Any, + ) -> GetTaskPayloadResult | ErrorData: + """Handle tasks/result from server.""" + event = task_complete_events.get(params.taskId) + assert event is not None, f"No completion event for task: {params.taskId}" + await event.wait() + result = await client_task_store.get_result(params.taskId) + assert result is not None, f"Result not found for task: {params.taskId}" + return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) + + return ExperimentalTaskHandlers( + augmented_elicitation=handle_augmented_elicitation, + get_task=handle_get_task, + get_task_result=handle_get_task_result, + ) + + +def create_sampling_task_handlers( + client_task_store: InMemoryTaskStore, + sampling_received: Event, +) -> ExperimentalTaskHandlers: + """Create task handlers for client to handle task-augmented sampling from server.""" + + sampling_response = CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="Hello from the model!"), + model="test-model", + ) + task_complete_events: dict[str, Event] = {} + + async def handle_augmented_sampling( + context: RequestContext[ClientSession, Any], + params: CreateMessageRequestParams, + task_metadata: TaskMetadata, + ) -> CreateTaskResult: + """Handle task-augmented sampling by creating a client-side task.""" + sampling_received.set() + task = await client_task_store.create_task(task_metadata) + task_complete_events[task.taskId] = Event() + + async def complete_task() -> None: + # Store result before updating status to avoid race condition + await client_task_store.store_result(task.taskId, sampling_response) + await client_task_store.update_task(task.taskId, status="completed") + task_complete_events[task.taskId].set() + + context.session._task_group.start_soon(complete_task) # pyright: ignore[reportPrivateUsage] + return CreateTaskResult(task=task) + + async def handle_get_task( + context: RequestContext[ClientSession, Any], + params: Any, + ) -> GetTaskResult: + """Handle tasks/get from server.""" + task = await client_task_store.get_task(params.taskId) + assert task is not None, f"Task not found: {params.taskId}" + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + lastUpdatedAt=task.lastUpdatedAt, + ttl=task.ttl, + pollInterval=100, + ) + + async def handle_get_task_result( + context: RequestContext[ClientSession, Any], + params: Any, + ) -> GetTaskPayloadResult | ErrorData: + """Handle tasks/result from server.""" + event = task_complete_events.get(params.taskId) + assert event is not None, f"No completion event for task: {params.taskId}" + await event.wait() + result = await client_task_store.get_result(params.taskId) + assert result is not None, f"Result not found for task: {params.taskId}" + return GetTaskPayloadResult.model_validate(result.model_dump(by_alias=True)) + + return ExperimentalTaskHandlers( + augmented_sampling=handle_augmented_sampling, + get_task=handle_get_task, + get_task_result=handle_get_task_result, + ) + + +@pytest.mark.anyio +async def test_scenario1_normal_tool_normal_elicitation() -> None: + """ + Scenario 1: Normal tool call with normal elicitation. + + Server calls session.elicit() directly, client responds immediately. + """ + server = Server("test-scenario1") + elicit_received = Event() + tool_result: list[str] = [] + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="confirm_action", + description="Confirm an action", + inputSchema={"type": "object"}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + ctx = server.request_context + + # Normal elicitation - expects immediate response + result = await ctx.session.elicit( + message="Please confirm the action", + requestedSchema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, + ) + + confirmed = result.content.get("confirm", False) if result.content else False + tool_result.append("confirmed" if confirmed else "cancelled") + return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + + # Elicitation callback for client + async def elicitation_callback( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, + ) -> ElicitResult: + elicit_received.set() + return ElicitResult(action="http://github.com/accept", content={"confirm": True}) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ) + + async def run_client() -> None: + async with ClientSession( + server_to_client_receive, + client_to_server_send, + elicitation_callback=elicitation_callback, + ) as client_session: + await client_session.initialize() + + # Call tool normally (not as task) + result = await client_session.call_tool("confirm_action", {}) + + # Verify elicitation was received and tool completed + assert elicit_received.is_set() + assert len(result.content) > 0 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "confirmed" + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + assert tool_result[0] == "confirmed" + + +@pytest.mark.anyio +async def test_scenario2_normal_tool_task_augmented_elicitation() -> None: + """ + Scenario 2: Normal tool call with task-augmented elicitation. + + Server calls session.experimental.elicit_as_task(), client creates a task + for the elicitation and returns CreateTaskResult. Server polls client. + """ + server = Server("test-scenario2") + elicit_received = Event() + tool_result: list[str] = [] + + # Client-side task store for handling task-augmented elicitation + client_task_store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="confirm_action", + description="Confirm an action", + inputSchema={"type": "object"}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + ctx = server.request_context + + # Task-augmented elicitation - server polls client + result = await ctx.session.experimental.elicit_as_task( + message="Please confirm the action", + requestedSchema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, + ttl=60000, + ) + + confirmed = result.content.get("confirm", False) if result.content else False + tool_result.append("confirmed" if confirmed else "cancelled") + return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + + task_handlers = create_client_task_handlers(client_task_store, elicit_received) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ) + + async def run_client() -> None: + async with ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as client_session: + await client_session.initialize() + + # Call tool normally (not as task) + result = await client_session.call_tool("confirm_action", {}) + + # Verify elicitation was received and tool completed + assert elicit_received.is_set() + assert len(result.content) > 0 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "confirmed" + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + assert tool_result[0] == "confirmed" + client_task_store.cleanup() + + +@pytest.mark.anyio +async def test_scenario3_task_augmented_tool_normal_elicitation() -> None: + """ + Scenario 3: Task-augmented tool call with normal elicitation. + + Client calls tool as task. Inside the task, server uses task.elicit() + which queues the request and delivers via tasks/result. + """ + server = Server("test-scenario3") + server.experimental.enable_tasks() + + elicit_received = Event() + work_completed = Event() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="confirm_action", + description="Confirm an action", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + # Normal elicitation within task - queued and delivered via tasks/result + result = await task.elicit( + message="Please confirm the action", + requestedSchema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, + ) + + confirmed = result.content.get("confirm", False) if result.content else False + work_completed.set() + return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + + return await ctx.experimental.run_task(work) + + # Elicitation callback for client + async def elicitation_callback( + context: RequestContext[ClientSession, Any], + params: ElicitRequestParams, + ) -> ElicitResult: + elicit_received.set() + return ElicitResult(action="http://github.com/accept", content={"confirm": True}) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ) + + async def run_client() -> None: + async with ClientSession( + server_to_client_receive, + client_to_server_send, + elicitation_callback=elicitation_callback, + ) as client_session: + await client_session.initialize() + + # Call tool as task + create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) + task_id = create_result.task.taskId + assert create_result.task.status == "working" + + # Poll until input_required, then call tasks/result + found_input_required = False + async for status in client_session.experimental.poll_task(task_id): # pragma: no branch + if status.status == "input_required": # pragma: no branch + found_input_required = True + break + assert found_input_required, "Expected to see input_required status" + + # This will deliver the elicitation and get the response + final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + + # Verify + assert elicit_received.is_set() + assert len(final_result.content) > 0 + assert isinstance(final_result.content[0], TextContent) + assert final_result.content[0].text == "confirmed" + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + assert work_completed.is_set() + + +@pytest.mark.anyio +async def test_scenario4_task_augmented_tool_task_augmented_elicitation() -> None: + """ + Scenario 4: Task-augmented tool call with task-augmented elicitation. + + Client calls tool as task. Inside the task, server uses task.elicit_as_task() + which sends task-augmented elicitation. Client creates its own task for the + elicitation, and server polls the client. + + This tests the full bidirectional flow where: + 1. Client calls tasks/result on server (for tool task) + 2. Server delivers task-augmented elicitation through that stream + 3. Client creates its own task and returns CreateTaskResult + 4. Server polls the client's task while the client's tasks/result is still open + 5. Server gets the ElicitResult and completes the tool task + 6. Client's tasks/result returns with the CallToolResult + """ + server = Server("test-scenario4") + server.experimental.enable_tasks() + + elicit_received = Event() + work_completed = Event() + + # Client-side task store for handling task-augmented elicitation + client_task_store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="confirm_action", + description="Confirm an action", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + # Task-augmented elicitation within task - server polls client + result = await task.elicit_as_task( + message="Please confirm the action", + requestedSchema={"type": "object", "properties": {"confirm": {"type": "boolean"}}}, + ttl=60000, + ) + + confirmed = result.content.get("confirm", False) if result.content else False + work_completed.set() + return CallToolResult(content=[TextContent(type="text", text="confirmed" if confirmed else "cancelled")]) + + return await ctx.experimental.run_task(work) + + task_handlers = create_client_task_handlers(client_task_store, elicit_received) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ) + + async def run_client() -> None: + async with ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as client_session: + await client_session.initialize() + + # Call tool as task + create_result = await client_session.experimental.call_tool_as_task("confirm_action", {}) + task_id = create_result.task.taskId + assert create_result.task.status == "working" + + # Poll until input_required or terminal, then call tasks/result + found_expected_status = False + async for status in client_session.experimental.poll_task(task_id): # pragma: no branch + if status.status == "input_required" or is_terminal(status.status): # pragma: no branch + found_expected_status = True + break + assert found_expected_status, "Expected to see input_required or terminal status" + + # This will deliver the task-augmented elicitation, + # server will poll client, and eventually return the tool result + final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + + # Verify + assert elicit_received.is_set() + assert len(final_result.content) > 0 + assert isinstance(final_result.content[0], TextContent) + assert final_result.content[0].text == "confirmed" + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + assert work_completed.is_set() + client_task_store.cleanup() + + +@pytest.mark.anyio +async def test_scenario2_sampling_normal_tool_task_augmented_sampling() -> None: + """ + Scenario 2 for sampling: Normal tool call with task-augmented sampling. + + Server calls session.experimental.create_message_as_task(), client creates + a task for the sampling and returns CreateTaskResult. Server polls client. + """ + server = Server("test-scenario2-sampling") + sampling_received = Event() + tool_result: list[str] = [] + + # Client-side task store for handling task-augmented sampling + client_task_store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="generate_text", + description="Generate text using sampling", + inputSchema={"type": "object"}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult: + ctx = server.request_context + + # Task-augmented sampling - server polls client + result = await ctx.session.experimental.create_message_as_task( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + max_tokens=100, + ttl=60000, + ) + + assert isinstance(result.content, TextContent), "Expected TextContent response" + response_text = result.content.text + + tool_result.append(response_text) + return CallToolResult(content=[TextContent(type="text", text=response_text)]) + + task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ) + + async def run_client() -> None: + async with ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as client_session: + await client_session.initialize() + + # Call tool normally (not as task) + result = await client_session.call_tool("generate_text", {}) + + # Verify sampling was received and tool completed + assert sampling_received.is_set() + assert len(result.content) > 0 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Hello from the model!" + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + assert tool_result[0] == "Hello from the model!" + client_task_store.cleanup() + + +@pytest.mark.anyio +async def test_scenario4_sampling_task_augmented_tool_task_augmented_sampling() -> None: + """ + Scenario 4 for sampling: Task-augmented tool call with task-augmented sampling. + + Client calls tool as task. Inside the task, server uses task.create_message_as_task() + which sends task-augmented sampling. Client creates its own task for the sampling, + and server polls the client. + """ + server = Server("test-scenario4-sampling") + server.experimental.enable_tasks() + + sampling_received = Event() + work_completed = Event() + + # Client-side task store for handling task-augmented sampling + client_task_store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="generate_text", + description="Generate text using sampling", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(TASK_REQUIRED) + + async def work(task: ServerTaskContext) -> CallToolResult: + # Task-augmented sampling within task - server polls client + result = await task.create_message_as_task( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))], + max_tokens=100, + ttl=60000, + ) + + assert isinstance(result.content, TextContent), "Expected TextContent response" + response_text = result.content.text + + work_completed.set() + return CallToolResult(content=[TextContent(type="text", text=response_text)]) + + return await ctx.experimental.run_task(work) + + task_handlers = create_sampling_task_handlers(client_task_store, sampling_received) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ) + + async def run_client() -> None: + async with ClientSession( + server_to_client_receive, + client_to_server_send, + experimental_task_handlers=task_handlers, + ) as client_session: + await client_session.initialize() + + # Call tool as task + create_result = await client_session.experimental.call_tool_as_task("generate_text", {}) + task_id = create_result.task.taskId + assert create_result.task.status == "working" + + # Poll until input_required or terminal + found_expected_status = False + async for status in client_session.experimental.poll_task(task_id): # pragma: no branch + if status.status == "input_required" or is_terminal(status.status): # pragma: no branch + found_expected_status = True + break + assert found_expected_status, "Expected to see input_required or terminal status" + + final_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + + # Verify + assert sampling_received.is_set() + assert len(final_result.content) > 0 + assert isinstance(final_result.content[0], TextContent) + assert final_result.content[0].text == "Hello from the model!" + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + tg.start_soon(run_client) + + assert work_completed.is_set() + client_task_store.cleanup() diff --git a/tests/experimental/tasks/test_message_queue.py b/tests/experimental/tasks/test_message_queue.py new file mode 100644 index 0000000000..86d6875cc4 --- /dev/null +++ b/tests/experimental/tasks/test_message_queue.py @@ -0,0 +1,331 @@ +""" +Tests for TaskMessageQueue and InMemoryTaskMessageQueue. +""" + +from datetime import datetime, timezone + +import anyio +import pytest + +from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, QueuedMessage +from mcp.shared.experimental.tasks.resolver import Resolver +from mcp.types import JSONRPCNotification, JSONRPCRequest + + +@pytest.fixture +def queue() -> InMemoryTaskMessageQueue: + return InMemoryTaskMessageQueue() + + +def make_request(id: int = 1, method: str = "test/method") -> JSONRPCRequest: + return JSONRPCRequest(jsonrpc="2.0", id=id, method=method) + + +def make_notification(method: str = "test/notify") -> JSONRPCNotification: + return JSONRPCNotification(jsonrpc="2.0", method=method) + + +class TestInMemoryTaskMessageQueue: + @pytest.mark.anyio + async def test_enqueue_and_dequeue(self, queue: InMemoryTaskMessageQueue) -> None: + """Test basic enqueue and dequeue operations.""" + task_id = "task-1" + msg = QueuedMessage(type="request", message=make_request()) + + await queue.enqueue(task_id, msg) + result = await queue.dequeue(task_id) + + assert result is not None + assert result.type == "request" + assert result.message.method == "test/method" + + @pytest.mark.anyio + async def test_dequeue_empty_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: + """Dequeue from empty queue returns None.""" + result = await queue.dequeue("nonexistent-task") + assert result is None + + @pytest.mark.anyio + async def test_fifo_ordering(self, queue: InMemoryTaskMessageQueue) -> None: + """Messages are dequeued in FIFO order.""" + task_id = "task-1" + + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1, "first"))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2, "second"))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3, "third"))) + + msg1 = await queue.dequeue(task_id) + msg2 = await queue.dequeue(task_id) + msg3 = await queue.dequeue(task_id) + + assert msg1 is not None and msg1.message.method == "first" + assert msg2 is not None and msg2.message.method == "second" + assert msg3 is not None and msg3.message.method == "third" + + @pytest.mark.anyio + async def test_separate_queues_per_task(self, queue: InMemoryTaskMessageQueue) -> None: + """Each task has its own queue.""" + await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1, "task1-msg"))) + await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2, "task2-msg"))) + + msg1 = await queue.dequeue("task-1") + msg2 = await queue.dequeue("task-2") + + assert msg1 is not None and msg1.message.method == "task1-msg" + assert msg2 is not None and msg2.message.method == "task2-msg" + + @pytest.mark.anyio + async def test_peek_does_not_remove(self, queue: InMemoryTaskMessageQueue) -> None: + """Peek returns message without removing it.""" + task_id = "task-1" + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) + + peeked = await queue.peek(task_id) + dequeued = await queue.dequeue(task_id) + + assert peeked is not None + assert dequeued is not None + assert isinstance(peeked.message, JSONRPCRequest) + assert isinstance(dequeued.message, JSONRPCRequest) + assert peeked.message.id == dequeued.message.id + + @pytest.mark.anyio + async def test_is_empty(self, queue: InMemoryTaskMessageQueue) -> None: + """Test is_empty method.""" + task_id = "task-1" + + assert await queue.is_empty(task_id) is True + + await queue.enqueue(task_id, QueuedMessage(type="notification", message=make_notification())) + assert await queue.is_empty(task_id) is False + + await queue.dequeue(task_id) + assert await queue.is_empty(task_id) is True + + @pytest.mark.anyio + async def test_clear_returns_all_messages(self, queue: InMemoryTaskMessageQueue) -> None: + """Clear removes and returns all messages.""" + task_id = "task-1" + + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(1))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(2))) + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request(3))) + + messages = await queue.clear(task_id) + + assert len(messages) == 3 + assert await queue.is_empty(task_id) is True + + @pytest.mark.anyio + async def test_clear_empty_queue(self, queue: InMemoryTaskMessageQueue) -> None: + """Clear on empty queue returns empty list.""" + messages = await queue.clear("nonexistent") + assert messages == [] + + @pytest.mark.anyio + async def test_notification_messages(self, queue: InMemoryTaskMessageQueue) -> None: + """Test queuing notification messages.""" + task_id = "task-1" + msg = QueuedMessage(type="notification", message=make_notification("log/message")) + + await queue.enqueue(task_id, msg) + result = await queue.dequeue(task_id) + + assert result is not None + assert result.type == "notification" + assert result.message.method == "log/message" + + @pytest.mark.anyio + async def test_message_timestamp(self, queue: InMemoryTaskMessageQueue) -> None: + """Messages have timestamps.""" + before = datetime.now(timezone.utc) + msg = QueuedMessage(type="request", message=make_request()) + after = datetime.now(timezone.utc) + + assert before <= msg.timestamp <= after + + @pytest.mark.anyio + async def test_message_with_resolver(self, queue: InMemoryTaskMessageQueue) -> None: + """Messages can have resolvers.""" + task_id = "task-1" + resolver: Resolver[dict[str, str]] = Resolver() + + msg = QueuedMessage( + type="request", + message=make_request(), + resolver=resolver, + original_request_id=42, + ) + + await queue.enqueue(task_id, msg) + result = await queue.dequeue(task_id) + + assert result is not None + assert result.resolver is resolver + assert result.original_request_id == 42 + + @pytest.mark.anyio + async def test_cleanup_specific_task(self, queue: InMemoryTaskMessageQueue) -> None: + """Cleanup removes specific task's data.""" + await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) + await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) + + queue.cleanup("task-1") + + assert await queue.is_empty("task-1") is True + assert await queue.is_empty("task-2") is False + + @pytest.mark.anyio + async def test_cleanup_all(self, queue: InMemoryTaskMessageQueue) -> None: + """Cleanup without task_id removes all data.""" + await queue.enqueue("task-1", QueuedMessage(type="request", message=make_request(1))) + await queue.enqueue("task-2", QueuedMessage(type="request", message=make_request(2))) + + queue.cleanup() + + assert await queue.is_empty("task-1") is True + assert await queue.is_empty("task-2") is True + + @pytest.mark.anyio + async def test_wait_for_message_returns_immediately_if_message_exists( + self, queue: InMemoryTaskMessageQueue + ) -> None: + """wait_for_message returns immediately if queue not empty.""" + task_id = "task-1" + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) + + # Should return immediately, not block + with anyio.fail_after(1): + await queue.wait_for_message(task_id) + + @pytest.mark.anyio + async def test_wait_for_message_blocks_until_message(self, queue: InMemoryTaskMessageQueue) -> None: + """wait_for_message blocks until a message is enqueued.""" + task_id = "task-1" + received = False + waiter_started = anyio.Event() + + async def enqueue_when_ready() -> None: + # Wait until the waiter has started before enqueueing + await waiter_started.wait() + await queue.enqueue(task_id, QueuedMessage(type="request", message=make_request())) + + async def wait_for_msg() -> None: + nonlocal received + # Signal that we're about to start waiting + waiter_started.set() + await queue.wait_for_message(task_id) + received = True + + async with anyio.create_task_group() as tg: + tg.start_soon(wait_for_msg) + tg.start_soon(enqueue_when_ready) + + assert received is True + + @pytest.mark.anyio + async def test_notify_message_available_wakes_waiter(self, queue: InMemoryTaskMessageQueue) -> None: + """notify_message_available wakes up waiting coroutines.""" + task_id = "task-1" + notified = False + waiter_started = anyio.Event() + + async def notify_when_ready() -> None: + # Wait until the waiter has started before notifying + await waiter_started.wait() + await queue.notify_message_available(task_id) + + async def wait_for_notification() -> None: + nonlocal notified + # Signal that we're about to start waiting + waiter_started.set() + await queue.wait_for_message(task_id) + notified = True + + async with anyio.create_task_group() as tg: + tg.start_soon(wait_for_notification) + tg.start_soon(notify_when_ready) + + assert notified is True + + @pytest.mark.anyio + async def test_peek_empty_queue_returns_none(self, queue: InMemoryTaskMessageQueue) -> None: + """Peek on empty queue returns None.""" + result = await queue.peek("nonexistent-task") + assert result is None + + @pytest.mark.anyio + async def test_wait_for_message_double_check_race_condition(self, queue: InMemoryTaskMessageQueue) -> None: + """wait_for_message returns early if message arrives after event creation but before wait.""" + task_id = "task-1" + + # To test the double-check path (lines 223-225), we need a message to arrive + # after the event is created (line 220) but before event.wait() (line 228). + # We simulate this by injecting a message before is_empty is called the second time. + + original_is_empty = queue.is_empty + call_count = 0 + + async def is_empty_with_injection(tid: str) -> bool: + nonlocal call_count + call_count += 1 + if call_count == 2 and tid == task_id: + # Before second check, inject a message - this simulates a message + # arriving between event creation and the double-check + queue._queues[task_id] = [QueuedMessage(type="request", message=make_request())] + return await original_is_empty(tid) + + queue.is_empty = is_empty_with_injection # type: ignore[method-assign] + + # Should return immediately due to double-check finding the message + with anyio.fail_after(1): + await queue.wait_for_message(task_id) + + +class TestResolver: + @pytest.mark.anyio + async def test_set_result_and_wait(self) -> None: + """Test basic set_result and wait flow.""" + resolver: Resolver[str] = Resolver() + + resolver.set_result("hello") + result = await resolver.wait() + + assert result == "hello" + assert resolver.done() + + @pytest.mark.anyio + async def test_set_exception_and_wait(self) -> None: + """Test set_exception raises on wait.""" + resolver: Resolver[str] = Resolver() + + resolver.set_exception(ValueError("test error")) + + with pytest.raises(ValueError, match="test error"): + await resolver.wait() + + assert resolver.done() + + @pytest.mark.anyio + async def test_set_result_when_already_completed_raises(self) -> None: + """Test that set_result raises if resolver already completed.""" + resolver: Resolver[str] = Resolver() + resolver.set_result("first") + + with pytest.raises(RuntimeError, match="already completed"): + resolver.set_result("second") + + @pytest.mark.anyio + async def test_set_exception_when_already_completed_raises(self) -> None: + """Test that set_exception raises if resolver already completed.""" + resolver: Resolver[str] = Resolver() + resolver.set_result("done") + + with pytest.raises(RuntimeError, match="already completed"): + resolver.set_exception(ValueError("too late")) + + @pytest.mark.anyio + async def test_done_returns_false_before_completion(self) -> None: + """Test done() returns False before any result is set.""" + resolver: Resolver[str] = Resolver() + assert resolver.done() is False diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py new file mode 100644 index 0000000000..5fa5da81af --- /dev/null +++ b/tests/experimental/tasks/test_request_context.py @@ -0,0 +1,166 @@ +"""Tests for the RequestContext.experimental (Experimental class) task validation helpers.""" + +import pytest + +from mcp.server.experimental.request_context import Experimental +from mcp.shared.exceptions import McpError +from mcp.types import ( + METHOD_NOT_FOUND, + TASK_FORBIDDEN, + TASK_OPTIONAL, + TASK_REQUIRED, + ClientCapabilities, + ClientTasksCapability, + TaskMetadata, + Tool, + ToolExecution, +) + + +def test_is_task_true_when_metadata_present() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + assert exp.is_task is True + + +def test_is_task_false_when_no_metadata() -> None: + exp = Experimental(task_metadata=None) + assert exp.is_task is False + + +def test_client_supports_tasks_true() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) + assert exp.client_supports_tasks is True + + +def test_client_supports_tasks_false_no_tasks() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.client_supports_tasks is False + + +def test_client_supports_tasks_false_no_capabilities() -> None: + exp = Experimental(_client_capabilities=None) + assert exp.client_supports_tasks is False + + +def test_validate_task_mode_required_with_task_is_valid() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) + assert error is None + + +def test_validate_task_mode_required_without_task_returns_error() -> None: + exp = Experimental(task_metadata=None) + error = exp.validate_task_mode(TASK_REQUIRED, raise_error=False) + assert error is not None + assert error.code == METHOD_NOT_FOUND + assert "requires task-augmented" in error.message + + +def test_validate_task_mode_required_without_task_raises_by_default() -> None: + exp = Experimental(task_metadata=None) + with pytest.raises(McpError) as exc_info: + exp.validate_task_mode(TASK_REQUIRED) + assert exc_info.value.error.code == METHOD_NOT_FOUND + + +def test_validate_task_mode_forbidden_without_task_is_valid() -> None: + exp = Experimental(task_metadata=None) + error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) + assert error is None + + +def test_validate_task_mode_forbidden_with_task_returns_error() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(TASK_FORBIDDEN, raise_error=False) + assert error is not None + assert error.code == METHOD_NOT_FOUND + assert "does not support task-augmented" in error.message + + +def test_validate_task_mode_forbidden_with_task_raises_by_default() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + with pytest.raises(McpError) as exc_info: + exp.validate_task_mode(TASK_FORBIDDEN) + assert exc_info.value.error.code == METHOD_NOT_FOUND + + +def test_validate_task_mode_none_treated_as_forbidden() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(None, raise_error=False) + assert error is not None + assert "does not support task-augmented" in error.message + + +def test_validate_task_mode_optional_with_task_is_valid() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) + assert error is None + + +def test_validate_task_mode_optional_without_task_is_valid() -> None: + exp = Experimental(task_metadata=None) + error = exp.validate_task_mode(TASK_OPTIONAL, raise_error=False) + assert error is None + + +def test_validate_for_tool_with_execution_required() -> None: + exp = Experimental(task_metadata=None) + tool = Tool( + name="test", + description="test", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + error = exp.validate_for_tool(tool, raise_error=False) + assert error is not None + assert "requires task-augmented" in error.message + + +def test_validate_for_tool_without_execution() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + tool = Tool( + name="test", + description="test", + inputSchema={"type": "object"}, + execution=None, + ) + error = exp.validate_for_tool(tool, raise_error=False) + assert error is not None + assert "does not support task-augmented" in error.message + + +def test_validate_for_tool_optional_with_task() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + tool = Tool( + name="test", + description="test", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_OPTIONAL), + ) + error = exp.validate_for_tool(tool, raise_error=False) + assert error is None + + +def test_can_use_tool_required_with_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities(tasks=ClientTasksCapability())) + assert exp.can_use_tool(TASK_REQUIRED) is True + + +def test_can_use_tool_required_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(TASK_REQUIRED) is False + + +def test_can_use_tool_optional_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(TASK_OPTIONAL) is True + + +def test_can_use_tool_forbidden_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(TASK_FORBIDDEN) is True + + +def test_can_use_tool_none_without_task_support() -> None: + exp = Experimental(_client_capabilities=ClientCapabilities()) + assert exp.can_use_tool(None) is True diff --git a/tests/experimental/tasks/test_spec_compliance.py b/tests/experimental/tasks/test_spec_compliance.py new file mode 100644 index 0000000000..842bfa7e1f --- /dev/null +++ b/tests/experimental/tasks/test_spec_compliance.py @@ -0,0 +1,753 @@ +""" +Tasks Spec Compliance Tests +=========================== + +Test structure mirrors: https://modelcontextprotocol.io/specification/draft/basic/utilities/tasks.md + +Each section contains tests for normative requirements (MUST/SHOULD/MAY). +""" + +from datetime import datetime, timezone + +import pytest + +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY +from mcp.types import ( + CancelTaskRequest, + CancelTaskResult, + CreateTaskResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerCapabilities, + Task, +) + +# Shared test datetime +TEST_DATETIME = datetime(2025, 1, 1, tzinfo=timezone.utc) + + +def _get_capabilities(server: Server) -> ServerCapabilities: + """Helper to get capabilities from a server.""" + return server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + +def test_server_without_task_handlers_has_no_tasks_capability() -> None: + """Server without any task handlers has no tasks capability.""" + server: Server = Server("test") + caps = _get_capabilities(server) + assert caps.tasks is None + + +def test_server_with_list_tasks_handler_declares_list_capability() -> None: + """Server with list_tasks handler declares tasks.list capability.""" + server: Server = Server("test") + + @server.experimental.list_tasks() + async def handle_list(req: ListTasksRequest) -> ListTasksResult: + raise NotImplementedError + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.list is not None + + +def test_server_with_cancel_task_handler_declares_cancel_capability() -> None: + """Server with cancel_task handler declares tasks.cancel capability.""" + server: Server = Server("test") + + @server.experimental.cancel_task() + async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: + raise NotImplementedError + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.cancel is not None + + +def test_server_with_get_task_handler_declares_requests_tools_call_capability() -> None: + """ + Server with get_task handler declares tasks.requests.tools.call capability. + (get_task is required for task-augmented tools/call support) + """ + server: Server = Server("test") + + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + raise NotImplementedError + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.tools is not None + + +def test_server_without_list_handler_has_no_list_capability() -> None: + """Server without list_tasks handler has no tasks.list capability.""" + server: Server = Server("test") + + # Register only get_task (not list_tasks) + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + raise NotImplementedError + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.list is None + + +def test_server_without_cancel_handler_has_no_cancel_capability() -> None: + """Server without cancel_task handler has no tasks.cancel capability.""" + server: Server = Server("test") + + # Register only get_task (not cancel_task) + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + raise NotImplementedError + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.cancel is None + + +def test_server_with_all_task_handlers_has_full_capability() -> None: + """Server with all task handlers declares complete tasks capability.""" + server: Server = Server("test") + + @server.experimental.list_tasks() + async def handle_list(req: ListTasksRequest) -> ListTasksResult: + raise NotImplementedError + + @server.experimental.cancel_task() + async def handle_cancel(req: CancelTaskRequest) -> CancelTaskResult: + raise NotImplementedError + + @server.experimental.get_task() + async def handle_get(req: GetTaskRequest) -> GetTaskResult: + raise NotImplementedError + + caps = _get_capabilities(server) + assert caps.tasks is not None + assert caps.tasks.list is not None + assert caps.tasks.cancel is not None + assert caps.tasks.requests is not None + assert caps.tasks.requests.tools is not None + + +class TestClientCapabilities: + """ + Clients declare: + - tasks.list — supports listing operations + - tasks.cancel — supports cancellation + - tasks.requests.sampling.createMessage — task-augmented sampling + - tasks.requests.elicitation.create — task-augmented elicitation + """ + + def test_client_declares_tasks_capability(self) -> None: + """Client can declare tasks capability.""" + pytest.skip("TODO") + + +class TestToolLevelNegotiation: + """ + Tools in tools/list responses include execution.taskSupport with values: + - Not present or "forbidden": No task augmentation allowed + - "optional": Task augmentation allowed at requestor discretion + - "required": Task augmentation is mandatory + """ + + def test_tool_execution_task_forbidden_rejects_task_augmented_call(self) -> None: + """Tool with execution.taskSupport="forbidden" MUST reject task-augmented calls (-32601).""" + pytest.skip("TODO") + + def test_tool_execution_task_absent_rejects_task_augmented_call(self) -> None: + """Tool without execution.taskSupport MUST reject task-augmented calls (-32601).""" + pytest.skip("TODO") + + def test_tool_execution_task_optional_accepts_normal_call(self) -> None: + """Tool with execution.taskSupport="optional" accepts normal calls.""" + pytest.skip("TODO") + + def test_tool_execution_task_optional_accepts_task_augmented_call(self) -> None: + """Tool with execution.taskSupport="optional" accepts task-augmented calls.""" + pytest.skip("TODO") + + def test_tool_execution_task_required_rejects_normal_call(self) -> None: + """Tool with execution.taskSupport="required" MUST reject non-task calls (-32601).""" + pytest.skip("TODO") + + def test_tool_execution_task_required_accepts_task_augmented_call(self) -> None: + """Tool with execution.taskSupport="required" accepts task-augmented calls.""" + pytest.skip("TODO") + + +class TestCapabilityNegotiation: + """ + Requestors SHOULD only augment requests with a task if the corresponding + capability has been declared by the receiver. + + Receivers that do not declare the task capability for a request type + MUST process requests of that type normally, ignoring any task-augmentation + metadata if present. + """ + + def test_receiver_without_capability_ignores_task_metadata(self) -> None: + """ + Receiver without task capability MUST process request normally, + ignoring task-augmentation metadata. + """ + pytest.skip("TODO") + + def test_receiver_with_capability_may_require_task_augmentation(self) -> None: + """ + Receivers that declare task capability MAY return error (-32600) + for non-task-augmented requests, requiring task augmentation. + """ + pytest.skip("TODO") + + +class TestTaskStatusLifecycle: + """ + Tasks begin in working status and follow valid transitions: + working → input_required → working → terminal + working → terminal (directly) + input_required → terminal (directly) + + Terminal states (no further transitions allowed): + - completed + - failed + - cancelled + """ + + def test_task_begins_in_working_status(self) -> None: + """Tasks MUST begin in working status.""" + pytest.skip("TODO") + + def test_working_to_completed_transition(self) -> None: + """working → completed is valid.""" + pytest.skip("TODO") + + def test_working_to_failed_transition(self) -> None: + """working → failed is valid.""" + pytest.skip("TODO") + + def test_working_to_cancelled_transition(self) -> None: + """working → cancelled is valid.""" + pytest.skip("TODO") + + def test_working_to_input_required_transition(self) -> None: + """working → input_required is valid.""" + pytest.skip("TODO") + + def test_input_required_to_working_transition(self) -> None: + """input_required → working is valid.""" + pytest.skip("TODO") + + def test_input_required_to_terminal_transition(self) -> None: + """input_required → terminal is valid.""" + pytest.skip("TODO") + + def test_terminal_state_no_further_transitions(self) -> None: + """Terminal states allow no further transitions.""" + pytest.skip("TODO") + + def test_completed_is_terminal(self) -> None: + """completed is a terminal state.""" + pytest.skip("TODO") + + def test_failed_is_terminal(self) -> None: + """failed is a terminal state.""" + pytest.skip("TODO") + + def test_cancelled_is_terminal(self) -> None: + """cancelled is a terminal state.""" + pytest.skip("TODO") + + +class TestInputRequiredStatus: + """ + When a receiver needs information to proceed, it moves the task to input_required. + The requestor should call tasks/result to retrieve input requests. + The task must include io.modelcontextprotocol/related-task metadata in associated requests. + """ + + def test_input_required_status_retrievable_via_tasks_get(self) -> None: + """Task in input_required status is retrievable via tasks/get.""" + pytest.skip("TODO") + + def test_input_required_related_task_metadata_in_requests(self) -> None: + """ + Task MUST include io.modelcontextprotocol/related-task metadata + in associated requests. + """ + pytest.skip("TODO") + + +class TestCreatingTask: + """ + Request structure: + {"method": "tools/call", "params": {"name": "...", "arguments": {...}, "task": {"ttl": 60000}}} + + Response (CreateTaskResult): + {"result": {"task": {"taskId": "...", "status": "working", ...}}} + + Receivers may include io.modelcontextprotocol/model-immediate-response in _meta. + """ + + def test_task_augmented_request_returns_create_task_result(self) -> None: + """Task-augmented request MUST return CreateTaskResult immediately.""" + pytest.skip("TODO") + + def test_create_task_result_contains_task_id(self) -> None: + """CreateTaskResult MUST contain taskId.""" + pytest.skip("TODO") + + def test_create_task_result_contains_status_working(self) -> None: + """CreateTaskResult MUST have status=working initially.""" + pytest.skip("TODO") + + def test_create_task_result_contains_created_at(self) -> None: + """CreateTaskResult MUST contain createdAt timestamp.""" + pytest.skip("TODO") + + def test_create_task_result_created_at_is_iso8601(self) -> None: + """createdAt MUST be ISO 8601 formatted.""" + pytest.skip("TODO") + + def test_create_task_result_may_contain_ttl(self) -> None: + """CreateTaskResult MAY contain ttl.""" + pytest.skip("TODO") + + def test_create_task_result_may_contain_poll_interval(self) -> None: + """CreateTaskResult MAY contain pollInterval.""" + pytest.skip("TODO") + + def test_create_task_result_may_contain_status_message(self) -> None: + """CreateTaskResult MAY contain statusMessage.""" + pytest.skip("TODO") + + def test_receiver_may_override_requested_ttl(self) -> None: + """Receiver MAY override requested ttl but MUST return actual value.""" + pytest.skip("TODO") + + def test_model_immediate_response_in_meta(self) -> None: + """ + Receiver MAY include io.modelcontextprotocol/model-immediate-response + in _meta to provide immediate response while task executes. + """ + # Verify the constant has the correct value per spec + assert MODEL_IMMEDIATE_RESPONSE_KEY == "io.modelcontextprotocol/model-immediate-response" + + # CreateTaskResult can include model-immediate-response in _meta + task = Task( + taskId="test-123", + status="working", + createdAt=TEST_DATETIME, + lastUpdatedAt=TEST_DATETIME, + ttl=60000, + ) + immediate_msg = "Task started, processing your request..." + # Note: Must use _meta= (alias) not meta= due to Pydantic alias handling + result = CreateTaskResult( + task=task, + **{"_meta": {MODEL_IMMEDIATE_RESPONSE_KEY: immediate_msg}}, + ) + + # Verify the metadata is present and correct + assert result.meta is not None + assert MODEL_IMMEDIATE_RESPONSE_KEY in result.meta + assert result.meta[MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg + + # Verify it serializes correctly with _meta alias + serialized = result.model_dump(by_alias=True) + assert "_meta" in serialized + assert MODEL_IMMEDIATE_RESPONSE_KEY in serialized["_meta"] + assert serialized["_meta"][MODEL_IMMEDIATE_RESPONSE_KEY] == immediate_msg + + +class TestGettingTaskStatus: + """ + Request: {"method": "tasks/get", "params": {"taskId": "..."}} + Response: Returns full Task object with current status and pollInterval. + """ + + def test_tasks_get_returns_task_object(self) -> None: + """tasks/get MUST return full Task object.""" + pytest.skip("TODO") + + def test_tasks_get_returns_current_status(self) -> None: + """tasks/get MUST return current status.""" + pytest.skip("TODO") + + def test_tasks_get_may_return_poll_interval(self) -> None: + """tasks/get MAY return pollInterval.""" + pytest.skip("TODO") + + def test_tasks_get_invalid_task_id_returns_error(self) -> None: + """tasks/get with invalid taskId MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_get_nonexistent_task_id_returns_error(self) -> None: + """tasks/get with nonexistent taskId MUST return -32602.""" + pytest.skip("TODO") + + +class TestRetrievingResults: + """ + Request: {"method": "tasks/result", "params": {"taskId": "..."}} + Response: The actual operation result structure (e.g., CallToolResult). + + This call blocks until terminal status. + """ + + def test_tasks_result_returns_underlying_result(self) -> None: + """tasks/result MUST return exactly what underlying request would return.""" + pytest.skip("TODO") + + def test_tasks_result_blocks_until_terminal(self) -> None: + """tasks/result MUST block for non-terminal tasks.""" + pytest.skip("TODO") + + def test_tasks_result_unblocks_on_terminal(self) -> None: + """tasks/result MUST unblock upon reaching terminal status.""" + pytest.skip("TODO") + + def test_tasks_result_includes_related_task_metadata(self) -> None: + """tasks/result MUST include io.modelcontextprotocol/related-task in _meta.""" + pytest.skip("TODO") + + def test_tasks_result_returns_error_for_failed_task(self) -> None: + """ + tasks/result returns the same error the underlying request + would have produced for failed tasks. + """ + pytest.skip("TODO") + + def test_tasks_result_invalid_task_id_returns_error(self) -> None: + """tasks/result with invalid taskId MUST return -32602.""" + pytest.skip("TODO") + + +class TestListingTasks: + """ + Request: {"method": "tasks/list", "params": {"cursor": "optional"}} + Response: Array of tasks with pagination support via nextCursor. + """ + + def test_tasks_list_returns_array_of_tasks(self) -> None: + """tasks/list MUST return array of tasks.""" + pytest.skip("TODO") + + def test_tasks_list_pagination_with_cursor(self) -> None: + """tasks/list supports pagination via cursor.""" + pytest.skip("TODO") + + def test_tasks_list_returns_next_cursor_when_more_results(self) -> None: + """tasks/list MUST return nextCursor when more results available.""" + pytest.skip("TODO") + + def test_tasks_list_cursors_are_opaque(self) -> None: + """Implementers MUST treat cursors as opaque tokens.""" + pytest.skip("TODO") + + def test_tasks_list_invalid_cursor_returns_error(self) -> None: + """tasks/list with invalid cursor MUST return -32602.""" + pytest.skip("TODO") + + +class TestCancellingTasks: + """ + Request: {"method": "tasks/cancel", "params": {"taskId": "..."}} + Response: Returns the task object with status: "cancelled". + """ + + def test_tasks_cancel_returns_cancelled_task(self) -> None: + """tasks/cancel MUST return task with status=cancelled.""" + pytest.skip("TODO") + + def test_tasks_cancel_terminal_task_returns_error(self) -> None: + """Cancelling already-terminal task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_completed_task_returns_error(self) -> None: + """Cancelling completed task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_failed_task_returns_error(self) -> None: + """Cancelling failed task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_already_cancelled_task_returns_error(self) -> None: + """Cancelling already-cancelled task MUST return -32602.""" + pytest.skip("TODO") + + def test_tasks_cancel_invalid_task_id_returns_error(self) -> None: + """tasks/cancel with invalid taskId MUST return -32602.""" + pytest.skip("TODO") + + +class TestStatusNotifications: + """ + Receivers MAY send: {"method": "notifications/tasks/status", "params": {...}} + These are optional; requestors MUST NOT rely on them and SHOULD continue polling. + """ + + def test_receiver_may_send_status_notification(self) -> None: + """Receiver MAY send notifications/tasks/status.""" + pytest.skip("TODO") + + def test_status_notification_contains_task_id(self) -> None: + """Status notification MUST contain taskId.""" + pytest.skip("TODO") + + def test_status_notification_contains_status(self) -> None: + """Status notification MUST contain status.""" + pytest.skip("TODO") + + +class TestTaskManagement: + """ + - Receivers generate unique task IDs as strings + - Tasks must begin in working status + - createdAt timestamps must be ISO 8601 formatted + - Receivers may override requested ttl but must return actual value + - Receivers may delete tasks after TTL expires + - All task-related messages must include io.modelcontextprotocol/related-task + in _meta except for tasks/get, tasks/list, tasks/cancel operations + """ + + def test_task_ids_are_unique_strings(self) -> None: + """Receivers MUST generate unique task IDs as strings.""" + pytest.skip("TODO") + + def test_multiple_tasks_have_unique_ids(self) -> None: + """Multiple tasks MUST have unique IDs.""" + pytest.skip("TODO") + + def test_receiver_may_delete_tasks_after_ttl(self) -> None: + """Receivers MAY delete tasks after TTL expires.""" + pytest.skip("TODO") + + def test_related_task_metadata_in_task_messages(self) -> None: + """ + All task-related messages MUST include io.modelcontextprotocol/related-task + in _meta. + """ + pytest.skip("TODO") + + def test_tasks_get_does_not_require_related_task_metadata(self) -> None: + """tasks/get does not require related-task metadata.""" + pytest.skip("TODO") + + def test_tasks_list_does_not_require_related_task_metadata(self) -> None: + """tasks/list does not require related-task metadata.""" + pytest.skip("TODO") + + def test_tasks_cancel_does_not_require_related_task_metadata(self) -> None: + """tasks/cancel does not require related-task metadata.""" + pytest.skip("TODO") + + +class TestResultHandling: + """ + - Receivers must return CreateTaskResult immediately upon accepting task-augmented requests + - tasks/result must return exactly what the underlying request would return + - tasks/result blocks for non-terminal tasks; must unblock upon reaching terminal status + """ + + def test_create_task_result_returned_immediately(self) -> None: + """Receiver MUST return CreateTaskResult immediately (not after work completes).""" + pytest.skip("TODO") + + def test_tasks_result_matches_underlying_result_structure(self) -> None: + """tasks/result MUST return same structure as underlying request.""" + pytest.skip("TODO") + + def test_tasks_result_for_tool_call_returns_call_tool_result(self) -> None: + """tasks/result for tools/call returns CallToolResult.""" + pytest.skip("TODO") + + +class TestProgressTracking: + """ + Task-augmented requests support progress notifications using the progressToken + mechanism, which remains valid throughout the task lifetime. + """ + + def test_progress_token_valid_throughout_task_lifetime(self) -> None: + """progressToken remains valid throughout task lifetime.""" + pytest.skip("TODO") + + def test_progress_notifications_sent_during_task_execution(self) -> None: + """Progress notifications can be sent during task execution.""" + pytest.skip("TODO") + + +class TestProtocolErrors: + """ + Protocol Errors (JSON-RPC standard codes): + - -32600 (Invalid request): Non-task requests to endpoint requiring task augmentation + - -32602 (Invalid params): Invalid/nonexistent taskId, invalid cursor, cancel terminal task + - -32603 (Internal error): Server-side execution failures + """ + + def test_invalid_request_for_required_task_augmentation(self) -> None: + """Non-task request to task-required endpoint returns -32600.""" + pytest.skip("TODO") + + def test_invalid_params_for_invalid_task_id(self) -> None: + """Invalid taskId returns -32602.""" + pytest.skip("TODO") + + def test_invalid_params_for_nonexistent_task_id(self) -> None: + """Nonexistent taskId returns -32602.""" + pytest.skip("TODO") + + def test_invalid_params_for_invalid_cursor(self) -> None: + """Invalid cursor in tasks/list returns -32602.""" + pytest.skip("TODO") + + def test_invalid_params_for_cancel_terminal_task(self) -> None: + """Attempt to cancel terminal task returns -32602.""" + pytest.skip("TODO") + + def test_internal_error_for_server_failure(self) -> None: + """Server-side execution failure returns -32603.""" + pytest.skip("TODO") + + +class TestTaskExecutionErrors: + """ + When underlying requests fail, the task moves to failed status. + - tasks/get response should include statusMessage explaining failure + - tasks/result returns same error the underlying request would have produced + - For tool calls, isError: true moves task to failed status + """ + + def test_underlying_failure_moves_task_to_failed(self) -> None: + """Underlying request failure moves task to failed status.""" + pytest.skip("TODO") + + def test_failed_task_has_status_message(self) -> None: + """Failed task SHOULD include statusMessage explaining failure.""" + pytest.skip("TODO") + + def test_tasks_result_returns_underlying_error(self) -> None: + """tasks/result returns same error underlying request would produce.""" + pytest.skip("TODO") + + def test_tool_call_is_error_true_moves_to_failed(self) -> None: + """Tool call with isError: true moves task to failed status.""" + pytest.skip("TODO") + + +class TestTaskObject: + """ + Task Object fields: + - taskId: String identifier + - status: Current execution state + - statusMessage: Optional human-readable description + - createdAt: ISO 8601 timestamp of creation + - ttl: Milliseconds before potential deletion + - pollInterval: Suggested milliseconds between polls + """ + + def test_task_has_task_id_string(self) -> None: + """Task MUST have taskId as string.""" + pytest.skip("TODO") + + def test_task_has_status(self) -> None: + """Task MUST have status.""" + pytest.skip("TODO") + + def test_task_status_message_is_optional(self) -> None: + """Task statusMessage is optional.""" + pytest.skip("TODO") + + def test_task_has_created_at(self) -> None: + """Task MUST have createdAt.""" + pytest.skip("TODO") + + def test_task_ttl_is_optional(self) -> None: + """Task ttl is optional.""" + pytest.skip("TODO") + + def test_task_poll_interval_is_optional(self) -> None: + """Task pollInterval is optional.""" + pytest.skip("TODO") + + +class TestRelatedTaskMetadata: + """ + Related Task Metadata structure: + {"_meta": {"io.modelcontextprotocol/related-task": {"taskId": "..."}}} + """ + + def test_related_task_metadata_structure(self) -> None: + """Related task metadata has correct structure.""" + pytest.skip("TODO") + + def test_related_task_metadata_contains_task_id(self) -> None: + """Related task metadata contains taskId.""" + pytest.skip("TODO") + + +class TestAccessAndIsolation: + """ + - Task IDs enable access to sensitive results + - Authorization context binding is essential where available + - For non-authorized environments: strong entropy IDs, strict TTL limits + """ + + def test_task_bound_to_authorization_context(self) -> None: + """ + Receivers receiving authorization context MUST bind tasks to that context. + """ + pytest.skip("TODO") + + def test_reject_task_operations_outside_authorization_context(self) -> None: + """ + Receivers MUST reject task operations for tasks outside + requestor's authorization context. + """ + pytest.skip("TODO") + + def test_non_authorized_environments_use_secure_ids(self) -> None: + """ + For non-authorized environments, receivers SHOULD use + cryptographically secure IDs. + """ + pytest.skip("TODO") + + def test_non_authorized_environments_use_shorter_ttls(self) -> None: + """ + For non-authorized environments, receivers SHOULD use shorter TTLs. + """ + pytest.skip("TODO") + + +class TestResourceLimits: + """ + Receivers should: + - Enforce concurrent task limits per requestor + - Implement maximum TTL constraints + - Clean up expired tasks promptly + """ + + def test_concurrent_task_limit_enforced(self) -> None: + """Receiver SHOULD enforce concurrent task limits per requestor.""" + pytest.skip("TODO") + + def test_maximum_ttl_constraint_enforced(self) -> None: + """Receiver SHOULD implement maximum TTL constraints.""" + pytest.skip("TODO") + + def test_expired_tasks_cleaned_up(self) -> None: + """Receiver SHOULD clean up expired tasks promptly.""" + pytest.skip("TODO") diff --git a/tests/issues/test_100_tool_listing.py b/tests/issues/test_100_tool_listing.py index 6dccec84d9..9e3447b741 100644 --- a/tests/issues/test_100_tool_listing.py +++ b/tests/issues/test_100_tool_listing.py @@ -13,7 +13,7 @@ async def test_list_tools_returns_all_tools(): for i in range(num_tools): @mcp.tool(name=f"tool_{i}") - def dummy_tool_func(): + def dummy_tool_func(): # pragma: no cover f"""Tool number {i}""" return i diff --git a/tests/issues/test_1027_win_unreachable_cleanup.py b/tests/issues/test_1027_win_unreachable_cleanup.py index 637f7963b2..63d6dd8dcf 100644 --- a/tests/issues/test_1027_win_unreachable_cleanup.py +++ b/tests/issues/test_1027_win_unreachable_cleanup.py @@ -95,7 +95,7 @@ def echo(text: str) -> str: async with ClientSession(read, write) as session: # Initialize the session result = await session.initialize() - assert result.protocolVersion in ["2024-11-05", "2025-06-18"] + assert result.protocolVersion in ["2024-11-05", "2025-06-18", "2025-11-25"] # Verify startup marker was created assert Path(startup_marker).exists(), "Server startup marker not created" @@ -110,7 +110,7 @@ def echo(text: str) -> str: # Give server a moment to complete cleanup with anyio.move_on_after(5.0): - while not Path(cleanup_marker).exists(): + while not Path(cleanup_marker).exists(): # pragma: no cover await anyio.sleep(0.1) # Verify cleanup marker was created - this works now that stdio_client @@ -121,9 +121,9 @@ def echo(text: str) -> str: finally: # Clean up files for path in [server_script, startup_marker, cleanup_marker]: - try: + try: # pragma: no cover Path(path).unlink() - except FileNotFoundError: + except FileNotFoundError: # pragma: no cover pass @@ -213,27 +213,27 @@ def echo(text: str) -> str: await anyio.sleep(0.1) # Check if process is still running - if hasattr(process, "returncode") and process.returncode is not None: + if hasattr(process, "returncode") and process.returncode is not None: # pragma: no cover pytest.fail(f"Server process exited with code {process.returncode}") assert Path(startup_marker).exists(), "Server startup marker not created" # Close stdin to signal shutdown - if process.stdin: + if process.stdin: # pragma: no branch await process.stdin.aclose() # Wait for process to exit gracefully try: with anyio.fail_after(5.0): # Increased from 2.0 to 5.0 await process.wait() - except TimeoutError: + except TimeoutError: # pragma: no cover # If it doesn't exit after stdin close, terminate it process.terminate() await process.wait() # Check if cleanup ran with anyio.move_on_after(5.0): - while not Path(cleanup_marker).exists(): + while not Path(cleanup_marker).exists(): # pragma: no cover await anyio.sleep(0.1) # Verify the cleanup ran - stdin closure enables graceful shutdown @@ -243,7 +243,7 @@ def echo(text: str) -> str: finally: # Clean up files for path in [server_script, startup_marker, cleanup_marker]: - try: + try: # pragma: no cover Path(path).unlink() - except FileNotFoundError: + except FileNotFoundError: # pragma: no cover pass diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index ec9264c471..958773d127 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -11,12 +11,12 @@ async def test_resource_templates(): # Add a dynamic greeting resource @mcp.resource("greeting://{name}") - def get_greeting(name: str) -> str: + def get_greeting(name: str) -> str: # pragma: no cover """Get a personalized greeting""" return f"Hello, {name}!" @mcp.resource("users://{user_id}/profile") - def get_user_profile(user_id: str) -> str: + def get_user_profile(user_id: str) -> str: # pragma: no cover """Dynamic user data""" return f"Profile data for user {user_id}" @@ -33,10 +33,10 @@ def get_user_profile(user_id: str) -> str: assert len(templates) == 2 # Verify template details - greeting_template = next(t for t in templates if t.name == "get_greeting") + greeting_template = next(t for t in templates if t.name == "get_greeting") # pragma: no cover assert greeting_template.uriTemplate == "greeting://{name}" assert greeting_template.description == "Get a personalized greeting" - profile_template = next(t for t in templates if t.name == "get_user_profile") + profile_template = next(t for t in templates if t.name == "get_user_profile") # pragma: no cover assert profile_template.uriTemplate == "users://{user_id}/profile" assert profile_template.description == "Dynamic user data" diff --git a/tests/issues/test_1338_icons_and_metadata.py b/tests/issues/test_1338_icons_and_metadata.py index 8a9897fcf7..adc37f1c6e 100644 --- a/tests/issues/test_1338_icons_and_metadata.py +++ b/tests/issues/test_1338_icons_and_metadata.py @@ -23,25 +23,25 @@ async def test_icons_and_website_url(): # Create tool with icon @mcp.tool(icons=[test_icon]) - def test_tool(message: str) -> str: + def test_tool(message: str) -> str: # pragma: no cover """A test tool with an icon.""" return message # Create resource with icon @mcp.resource("test://resource", icons=[test_icon]) - def test_resource() -> str: + def test_resource() -> str: # pragma: no cover """A test resource with an icon.""" return "test content" # Create prompt with icon @mcp.prompt("test_prompt", icons=[test_icon]) - def test_prompt(text: str) -> str: + def test_prompt(text: str) -> str: # pragma: no cover """A test prompt with an icon.""" return text # Create resource template with icon @mcp.resource("test://weather/{city}", icons=[test_icon]) - def test_resource_template(city: str) -> str: + def test_resource_template(city: str) -> str: # pragma: no cover """Get weather for a city.""" return f"Weather for {city}" @@ -104,7 +104,7 @@ async def test_multiple_icons(): # Create tool with multiple icons @mcp.tool(icons=[icon1, icon2, icon3]) - def multi_icon_tool() -> str: + def multi_icon_tool() -> str: # pragma: no cover """A tool with multiple icons.""" return "success" @@ -125,7 +125,7 @@ async def test_no_icons_or_website(): mcp = FastMCP("BasicServer") @mcp.tool() - def basic_tool() -> str: + def basic_tool() -> str: # pragma: no cover """A basic tool without icons.""" return "success" diff --git a/tests/issues/test_1363_race_condition_streamable_http.py b/tests/issues/test_1363_race_condition_streamable_http.py new file mode 100644 index 0000000000..49242d6d8b --- /dev/null +++ b/tests/issues/test_1363_race_condition_streamable_http.py @@ -0,0 +1,278 @@ +"""Test for issue #1363 - Race condition in StreamableHTTP transport causes ClosedResourceError. + +This test reproduces the race condition described in issue #1363 where MCP servers +in HTTP Streamable mode experience ClosedResourceError exceptions when requests +fail validation early (e.g., due to incorrect Accept headers). + +The race condition occurs because: +1. Transport setup creates a message_router task +2. Message router enters async for write_stream_reader loop +3. write_stream_reader calls checkpoint() in receive(), yielding control +4. Request handling processes HTTP request +5. If validation fails early, request returns immediately +6. Transport termination closes all streams including write_stream_reader +7. Message router may still be in checkpoint() yield and hasn't returned to check stream state +8. When message router resumes, it encounters a closed stream, raising ClosedResourceError +""" + +import logging +import threading +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import anyio +import httpx +import pytest +from starlette.applications import Starlette +from starlette.routing import Mount + +from mcp.server import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + +SERVER_NAME = "test_race_condition_server" + + +class RaceConditionTestServer(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + +def create_app(json_response: bool = False) -> Starlette: + """Create a Starlette application for testing.""" + app = RaceConditionTestServer() + + # Create session manager + session_manager = StreamableHTTPSessionManager( + app=app, + json_response=json_response, + stateless=True, # Use stateless mode to trigger the race condition + ) + + # Create Starlette app with lifespan + @asynccontextmanager + async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: + async with session_manager.run(): + yield + + routes = [ + Mount("/", app=session_manager.handle_request), + ] + + return Starlette(routes=routes, lifespan=lifespan) + + +class ServerThread(threading.Thread): + """Thread that runs the ASGI application lifespan in a separate event loop.""" + + def __init__(self, app: Starlette): + super().__init__(daemon=True) + self.app = app + self._stop_event = threading.Event() + + def run(self) -> None: + """Run the lifespan in a new event loop.""" + + # Create a new event loop for this thread + async def run_lifespan(): + # Use the lifespan context (always present in our tests) + lifespan_context = getattr(self.app.router, "lifespan_context", None) + assert lifespan_context is not None # Tests always create apps with lifespan + async with lifespan_context(self.app): + # Wait until stop is requested + while not self._stop_event.is_set(): + await anyio.sleep(0.1) + + anyio.run(run_lifespan) + + def stop(self) -> None: + """Signal the thread to stop.""" + self._stop_event.set() + + +def check_logs_for_race_condition_errors(caplog: pytest.LogCaptureFixture, test_name: str) -> None: + """ + Check logs for ClosedResourceError and other race condition errors. + + Args: + caplog: pytest log capture fixture + test_name: Name of the test for better error messages + """ + # Check for specific race condition errors in logs + errors_found: list[str] = [] + + for record in caplog.records: # pragma: no cover + message = record.getMessage() + if "ClosedResourceError" in message: + errors_found.append("ClosedResourceError") + if "Error in message router" in message: + errors_found.append("Error in message router") + if "anyio.ClosedResourceError" in message: + errors_found.append("anyio.ClosedResourceError") + + # Assert no race condition errors occurred + if errors_found: # pragma: no cover + error_msg = f"Test '{test_name}' found race condition errors in logs: {', '.join(set(errors_found))}\n" + error_msg += "Log records:\n" + for record in caplog.records: + if any(err in record.getMessage() for err in ["ClosedResourceError", "Error in message router"]): + error_msg += f" {record.levelname}: {record.getMessage()}\n" + pytest.fail(error_msg) + + +@pytest.mark.anyio +async def test_race_condition_invalid_accept_headers(caplog: pytest.LogCaptureFixture): + """ + Test the race condition with invalid Accept headers. + + This test reproduces the exact scenario described in issue #1363: + - Send POST request with incorrect Accept headers (missing either application/json or text/event-stream) + - Request fails validation early and returns quickly + - This should trigger the race condition where message_router encounters ClosedResourceError + """ + app = create_app() + server_thread = ServerThread(app) + server_thread.start() + + try: + # Give the server thread a moment to start + await anyio.sleep(0.1) + + # Suppress WARNING logs (expected validation errors) and capture ERROR logs + with caplog.at_level(logging.ERROR): + # Test with missing text/event-stream in Accept header + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver/", timeout=5.0 + ) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "application/json", # Missing text/event-stream + "Content-Type": "application/json", + }, + ) + # Should get 406 Not Acceptable due to missing text/event-stream + assert response.status_code == 406 + + # Test with missing application/json in Accept header + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver/", timeout=5.0 + ) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "text/event-stream", # Missing application/json + "Content-Type": "application/json", + }, + ) + # Should get 406 Not Acceptable due to missing application/json + assert response.status_code == 406 + + # Test with completely invalid Accept header + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver/", timeout=5.0 + ) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "text/plain", # Invalid Accept header + "Content-Type": "application/json", + }, + ) + # Should get 406 Not Acceptable + assert response.status_code == 406 + + # Give background tasks time to complete + await anyio.sleep(0.2) + + finally: + server_thread.stop() + server_thread.join(timeout=5.0) + # Check logs for race condition errors + check_logs_for_race_condition_errors(caplog, "test_race_condition_invalid_accept_headers") + + +@pytest.mark.anyio +async def test_race_condition_invalid_content_type(caplog: pytest.LogCaptureFixture): + """ + Test the race condition with invalid Content-Type headers. + + This test reproduces the race condition scenario with Content-Type validation failure. + """ + app = create_app() + server_thread = ServerThread(app) + server_thread.start() + + try: + # Give the server thread a moment to start + await anyio.sleep(0.1) + + # Suppress WARNING logs (expected validation errors) and capture ERROR logs + with caplog.at_level(logging.ERROR): + # Test with invalid Content-Type + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver/", timeout=5.0 + ) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", # Invalid Content-Type + }, + ) + assert response.status_code == 400 + + # Give background tasks time to complete + await anyio.sleep(0.2) + + finally: + server_thread.stop() + server_thread.join(timeout=5.0) + # Check logs for race condition errors + check_logs_for_race_condition_errors(caplog, "test_race_condition_invalid_content_type") + + +@pytest.mark.anyio +async def test_race_condition_message_router_async_for(caplog: pytest.LogCaptureFixture): + """ + Uses json_response=True to trigger the `if self.is_json_response_enabled` branch, + which reproduces the ClosedResourceError when message_router is suspended + in async for loop while transport cleanup closes streams concurrently. + """ + app = create_app(json_response=True) + server_thread = ServerThread(app) + server_thread.start() + + try: + # Give the server thread a moment to start + await anyio.sleep(0.1) + + # Suppress WARNING logs (expected validation errors) and capture ERROR logs + with caplog.at_level(logging.ERROR): + # Use httpx.ASGITransport to test the ASGI app directly + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://testserver/", timeout=5.0 + ) as client: + # Send a valid initialize request + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + # Should get a successful response + assert response.status_code in (200, 201) + + # Give background tasks time to complete + await anyio.sleep(0.2) + + finally: + server_thread.stop() + server_thread.join(timeout=5.0) + # Check logs for race condition errors in message router + check_logs_for_race_condition_errors(caplog, "test_race_condition_message_router_async_for") diff --git a/tests/issues/test_141_resource_templates.py b/tests/issues/test_141_resource_templates.py index 3145f65e8c..0a0484d894 100644 --- a/tests/issues/test_141_resource_templates.py +++ b/tests/issues/test_141_resource_templates.py @@ -25,28 +25,28 @@ def get_user_post(user_id: str, post_id: str) -> str: with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://users/{user_id}/profile") - def get_user_profile(user_id: str, optional_param: str | None = None) -> str: + def get_user_profile(user_id: str, optional_param: str | None = None) -> str: # pragma: no cover return f"Profile for user {user_id}" # Test case 3: Template with mismatched parameters with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://users/{user_id}/profile") - def get_user_profile_mismatch(different_param: str) -> str: + def get_user_profile_mismatch(different_param: str) -> str: # pragma: no cover return f"Profile for user {different_param}" # Test case 4: Template with extra function parameters with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://users/{user_id}/profile") - def get_user_profile_extra(user_id: str, extra_param: str) -> str: + def get_user_profile_extra(user_id: str, extra_param: str) -> str: # pragma: no cover return f"Profile for user {user_id}" # Test case 5: Template with missing function parameters with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://users/{user_id}/profile/{section}") - def get_user_profile_missing(user_id: str) -> str: + def get_user_profile_missing(user_id: str) -> str: # pragma: no cover return f"Profile for user {user_id}" # Verify valid template works diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index a99e5a5c75..2a8cd6202e 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -88,7 +88,7 @@ async def handle_read_resource(uri: AnyUrl): return [ReadResourceContents(content=base64_string, mime_type="image/png")] elif str(uri) == "test://image_bytes": return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] - raise Exception(f"Resource not found: {uri}") + raise Exception(f"Resource not found: {uri}") # pragma: no cover # Test that resources are listed with correct mime type async with client_session(server) as client: diff --git a/tests/issues/test_1754_mime_type_parameters.py b/tests/issues/test_1754_mime_type_parameters.py new file mode 100644 index 0000000000..cd8239ad2a --- /dev/null +++ b/tests/issues/test_1754_mime_type_parameters.py @@ -0,0 +1,70 @@ +"""Test for GitHub issue #1754: MIME type validation rejects valid RFC 2045 parameters. + +The MIME type validation regex was too restrictive and rejected valid MIME types +with parameters like 'text/html;profile=mcp-app' which are valid per RFC 2045. +""" + +import pytest +from pydantic import AnyUrl + +from mcp.server.fastmcp import FastMCP +from mcp.shared.memory import ( + create_connected_server_and_client_session as client_session, +) + +pytestmark = pytest.mark.anyio + + +async def test_mime_type_with_parameters(): + """Test that MIME types with parameters are accepted (RFC 2045).""" + mcp = FastMCP("test") + + # This should NOT raise a validation error + @mcp.resource("ui://widget", mime_type="text/html;profile=mcp-app") + def widget() -> str: + raise NotImplementedError() + + resources = await mcp.list_resources() + assert len(resources) == 1 + assert resources[0].mimeType == "text/html;profile=mcp-app" + + +async def test_mime_type_with_parameters_and_space(): + """Test MIME type with space after semicolon.""" + mcp = FastMCP("test") + + @mcp.resource("data://json", mime_type="application/json; charset=utf-8") + def data() -> str: + raise NotImplementedError() + + resources = await mcp.list_resources() + assert len(resources) == 1 + assert resources[0].mimeType == "application/json; charset=utf-8" + + +async def test_mime_type_with_multiple_parameters(): + """Test MIME type with multiple parameters.""" + mcp = FastMCP("test") + + @mcp.resource("data://multi", mime_type="text/plain; charset=utf-8; format=fixed") + def data() -> str: + raise NotImplementedError() + + resources = await mcp.list_resources() + assert len(resources) == 1 + assert resources[0].mimeType == "text/plain; charset=utf-8; format=fixed" + + +async def test_mime_type_preserved_in_read_resource(): + """Test that MIME type with parameters is preserved when reading resource.""" + mcp = FastMCP("test") + + @mcp.resource("ui://my-widget", mime_type="text/html;profile=mcp-app") + def my_widget() -> str: + return "Hello MCP-UI" + + async with client_session(mcp._mcp_server) as client: + # Read the resource + result = await client.read_resource(AnyUrl("ui://my-widget")) + assert len(result.contents) == 1 + assert result.contents[0].mimeType == "text/html;profile=mcp-app" diff --git a/tests/issues/test_355_type_error.py b/tests/issues/test_355_type_error.py index 7159308b23..63ed803846 100644 --- a/tests/issues/test_355_type_error.py +++ b/tests/issues/test_355_type_error.py @@ -8,13 +8,13 @@ class Database: # Replace with your actual DB type @classmethod - async def connect(cls): + async def connect(cls): # pragma: no cover return cls() - async def disconnect(self): + async def disconnect(self): # pragma: no cover pass - def query(self): + def query(self): # pragma: no cover return "Hello, World!" @@ -28,7 +28,7 @@ class AppContext: @asynccontextmanager -async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: +async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # pragma: no cover """Manage application lifecycle with type-safe context""" # Initialize on startup db = await Database.connect() @@ -45,7 +45,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerSession, AppContext]) -> str: # pragma: no cover """Tool that uses initialized resources""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/tests/issues/test_552_windows_hang.py b/tests/issues/test_552_windows_hang.py index 8dbdf33340..972659c2b7 100644 --- a/tests/issues/test_552_windows_hang.py +++ b/tests/issues/test_552_windows_hang.py @@ -10,7 +10,7 @@ from mcp.client.stdio import stdio_client -@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific test") +@pytest.mark.skipif(sys.platform != "win32", reason="Windows-specific test") # pragma: no cover @pytest.mark.anyio async def test_windows_stdio_client_with_session(): """ diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 5584abcaea..ac370ca160 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -1,7 +1,6 @@ """Test to reproduce issue #88: Random error thrown on response.""" from collections.abc import Sequence -from datetime import timedelta from pathlib import Path from typing import Any @@ -27,6 +26,10 @@ async def test_notification_validation_error(tmp_path: Path): 2. The server can still handle new requests 3. The client can make new requests 4. No resources are leaked + + Uses per-request timeouts to avoid race conditions: + - Fast operations use no timeout (reliable in any environment) + - Slow operations use minimal timeout (10ms) for quick test execution """ server = Server(name="test") @@ -58,7 +61,7 @@ async def slow_tool(name: str, arguments: dict[str, Any]) -> Sequence[ContentBlo return [TextContent(type="text", text=f"slow {request_count}")] elif name == "fast": return [TextContent(type="text", text=f"fast {request_count}")] - return [TextContent(type="text", text=f"unknown {request_count}")] + return [TextContent(type="text", text=f"unknown {request_count}")] # pragma: no cover async def server_handler( read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -79,31 +82,29 @@ async def client( write_stream: MemoryObjectSendStream[SessionMessage], scope: anyio.CancelScope, ): - # Use a timeout that's: - # - Long enough for fast operations (>10ms) - # - Short enough for slow operations (<200ms) - # - Not too short to avoid flakiness - async with ClientSession(read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)) as session: + # No session-level timeout to avoid race conditions with fast operations + async with ClientSession(read_stream, write_stream) as session: await session.initialize() - # First call should work (fast operation) - result = await session.call_tool("fast") + # First call should work (fast operation, no timeout) + result = await session.call_tool("fast", read_timeout_seconds=None) assert result.content == [TextContent(type="text", text="fast 1")] assert not slow_request_lock.is_set() - # Second call should timeout (slow operation) + # Second call should timeout (slow operation with minimal timeout) + # Use very small timeout to trigger quickly without waiting with pytest.raises(McpError) as exc_info: - await session.call_tool("slow") + await session.call_tool("slow", read_timeout_seconds=0.000001) # artificial timeout that always fails assert "Timed out while waiting" in str(exc_info.value) # release the slow request not to have hanging process slow_request_lock.set() - # Third call should work (fast operation), + # Third call should work (fast operation, no timeout), # proving server is still responsive - result = await session.call_tool("fast") + result = await session.call_tool("fast", read_timeout_seconds=None) assert result.content == [TextContent(type="text", text="fast 3")] - scope.cancel() + scope.cancel() # pragma: no cover # Run server and client in separate task groups to avoid cancellation server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage](1) diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 065bc78419..078beb7a58 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -89,9 +89,9 @@ async def test_malformed_initialize_request_does_not_crash_server(): assert second_response.id == "test_id_2" assert second_response.error.code == INVALID_PARAMS - except anyio.WouldBlock: + except anyio.WouldBlock: # pragma: no cover pytest.fail("No response received - server likely crashed") - finally: + finally: # pragma: no cover # Close all streams to ensure proper cleanup await read_send_stream.aclose() await write_send_stream.aclose() @@ -154,7 +154,7 @@ async def test_multiple_concurrent_malformed_requests(): assert isinstance(response, JSONRPCError) assert response.id == f"malformed_{i}" assert response.error.code == INVALID_PARAMS - finally: + finally: # pragma: no cover # Close all streams to ensure proper cleanup await read_send_stream.aclose() await write_send_stream.aclose() diff --git a/tests/server/auth/middleware/test_auth_context.py b/tests/server/auth/middleware/test_auth_context.py index 9166407147..1cca4df5ab 100644 --- a/tests/server/auth/middleware/test_auth_context.py +++ b/tests/server/auth/middleware/test_auth_context.py @@ -61,10 +61,10 @@ async def test_with_authenticated_user(self, valid_access_token: AccessToken): scope: Scope = {"type": "http", "user": user} # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} - async def send(message: Message) -> None: + async def send(message: Message) -> None: # pragma: no cover pass # Verify context is empty before middleware @@ -95,10 +95,10 @@ async def test_with_no_user(self): scope: Scope = {"type": "http"} # No user # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} - async def send(message: Message) -> None: + async def send(message: Message) -> None: # pragma: no cover pass # Verify context is empty before middleware diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index 80c8bae21a..e13ab96390 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -276,7 +276,7 @@ async def test_no_user(self): scope: Scope = {"type": "http"} # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} sent_messages: list[Message] = [] @@ -300,7 +300,7 @@ async def test_non_authenticated_user(self): scope: Scope = {"type": "http", "user": object()} # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} sent_messages: list[Message] = [] @@ -329,7 +329,7 @@ async def test_missing_required_scope(self, valid_access_token: AccessToken): scope: Scope = {"type": "http", "user": user, "auth": auth} # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} sent_messages: list[Message] = [] @@ -357,7 +357,7 @@ async def test_no_auth_credentials(self, valid_access_token: AccessToken): scope: Scope = {"type": "http", "user": user} # No auth credentials # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} sent_messages: list[Message] = [] @@ -386,10 +386,10 @@ async def test_has_required_scopes(self, valid_access_token: AccessToken): scope: Scope = {"type": "http", "user": user, "auth": auth} # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} - async def send(message: Message) -> None: + async def send(message: Message) -> None: # pragma: no cover pass await middleware(scope, receive, send) @@ -411,10 +411,10 @@ async def test_multiple_required_scopes(self, valid_access_token: AccessToken): scope: Scope = {"type": "http", "user": user, "auth": auth} # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} - async def send(message: Message) -> None: + async def send(message: Message) -> None: # pragma: no cover pass await middleware(scope, receive, send) @@ -436,10 +436,10 @@ async def test_no_required_scopes(self, valid_access_token: AccessToken): scope: Scope = {"type": "http", "user": user, "auth": auth} # Create dummy async functions for receive and send - async def receive() -> Message: + async def receive() -> Message: # pragma: no cover return {"type": "http.request"} - async def send(message: Message) -> None: + async def send(message: Message) -> None: # pragma: no cover pass await middleware(scope, receive, send) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index fa33fbf43d..7342013a81 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -12,7 +12,7 @@ import httpx import pytest -from pydantic import AnyHttpUrl +from pydantic import AnyHttpUrl, AnyUrl from starlette.applications import Starlette from mcp.server.auth.provider import ( @@ -39,11 +39,13 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: return self.clients.get(client_id) async def register_client(self, client_info: OAuthClientInformationFull): + assert client_info.client_id is not None self.clients[client_info.client_id] = client_info async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: # toy authorize implementation which just immediately generates an authorization # code and completes the redirect + assert client.client_id is not None code = AuthorizationCode( code=f"code_{int(time.time())}", client_id=client.client_id, @@ -72,6 +74,7 @@ async def exchange_authorization_code( refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the tokens + assert client.client_id is not None self.tokens[access_token] = AccessToken( token=access_token, client_id=client.client_id, @@ -97,7 +100,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t if old_access_token is None: return None token_info = self.tokens.get(old_access_token) - if token_info is None: + if token_info is None: # pragma: no cover return None # Create a RefreshToken object that matches what is expected in later code @@ -133,6 +136,7 @@ async def exchange_refresh_token( new_refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the new tokens + assert client.client_id is not None self.tokens[new_access_token] = AccessToken( token=new_access_token, client_id=client.client_id, @@ -170,17 +174,17 @@ async def load_access_token(self, token: str) -> AccessToken | None: async def revoke_token(self, token: AccessToken | RefreshToken) -> None: match token: - case RefreshToken(): + case RefreshToken(): # pragma: no cover # Remove the refresh token del self.refresh_tokens[token.token] - case AccessToken(): + case AccessToken(): # pragma: no branch # Remove the access token del self.tokens[token.token] # Also remove any refresh tokens that point to this access token for refresh_token, access_token in list(self.refresh_tokens.items()): - if access_token == token.token: + if access_token == token.token: # pragma: no branch del self.refresh_tokens[refresh_token] @@ -279,7 +283,7 @@ async def auth_code( } # Override with any parameters from the test - if hasattr(request, "param") and request.param: + if hasattr(request, "param") and request.param: # pragma: no cover auth_params.update(request.param) response = await test_client.get("/authorize", params=auth_params) @@ -300,44 +304,6 @@ async def auth_code( } -@pytest.fixture -async def tokens( - test_client: httpx.AsyncClient, - registered_client: dict[str, Any], - auth_code: dict[str, str], - pkce_challenge: dict[str, str], - request: pytest.FixtureRequest, -): - """Exchange authorization code for tokens. - - Parameters can be customized via indirect parameterization: - @pytest.mark.parametrize("tokens", - [{"code_verifier": "wrong_verifier"}], - indirect=True) - """ - # Default token request params - token_params = { - "grant_type": "authorization_code", - "client_id": registered_client["client_id"], - "client_secret": registered_client["client_secret"], - "code": auth_code["code"], - "code_verifier": pkce_challenge["code_verifier"], - "redirect_uri": auth_code["redirect_uri"], - } - - # Override with any parameters from the test - if hasattr(request, "param") and request.param: - token_params.update(request.param) - - response = await test_client.post("/token", data=token_params) - - # Don't assert success here since some tests will intentionally cause errors - return { - "response": response, - "params": token_params, - } - - class TestAuthEndpoints: @pytest.mark.anyio async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): @@ -354,7 +320,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" assert metadata["response_types_supported"] == ["code"] assert metadata["code_challenge_methods_supported"] == ["S256"] - assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"] + assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post", "client_secret_basic"] assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", @@ -373,8 +339,58 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient): }, ) error_response = response.json() - assert error_response["error"] == "invalid_request" - assert "error_description" in error_response # Contains validation error messages + # Per RFC 6749 Section 5.2, authentication failures (missing client_id) + # must return "invalid_client", not "unauthorized_client" + assert error_response["error"] == "invalid_client" + assert "error_description" in error_response # Contains error message + + @pytest.mark.anyio + async def test_token_invalid_client_secret_returns_invalid_client( + self, + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], + mock_oauth_provider: MockOAuthProvider, + ): + """Test token endpoint returns 'invalid_client' for wrong client_secret per RFC 6749. + + RFC 6749 Section 5.2 defines: + - invalid_client: Client authentication failed (wrong credentials, unknown client) + - unauthorized_client: Authenticated client not authorized for grant type + + When client_secret is wrong, this is an authentication failure, so the + error code MUST be 'invalid_client'. + """ + # Create an auth code for the registered client + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=registered_client["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + # Try to exchange the auth code with a WRONG client_secret + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": "wrong_secret_that_does_not_match", + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + + assert response.status_code == 401 + error_response = response.json() + # RFC 6749 Section 5.2: authentication failures MUST return "invalid_client" + assert error_response["error"] == "invalid_client" + assert "Invalid client_secret" in error_response["error_description"] @pytest.mark.anyio async def test_token_invalid_auth_code( @@ -418,8 +434,8 @@ async def test_token_expired_auth_code( # Find the auth code object code_value = auth_code["code"] found_code = None - for code_obj in mock_oauth_provider.auth_codes.values(): - if code_obj.code == code_value: + for code_obj in mock_oauth_provider.auth_codes.values(): # pragma: no branch + if code_obj.code == code_value: # pragma: no branch found_code = code_obj break @@ -1010,6 +1026,335 @@ async def test_client_registration_default_response_types( assert "response_types" in data assert data["response_types"] == ["code"] + @pytest.mark.anyio + async def test_client_secret_basic_authentication( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that client_secret_basic authentication works correctly.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Basic Auth Client", + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + assert client_info["token_endpoint_auth_method"] == "client_secret_basic" + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + credentials = f"{client_info['client_id']}:{client_info['client_secret']}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {encoded_credentials}"}, + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 200 + token_response = response.json() + assert "access_token" in token_response + + @pytest.mark.anyio + async def test_wrong_auth_method_without_valid_credentials_fails( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that using the wrong authentication method fails when credentials are missing.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Post Auth Client", + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + assert client_info["token_endpoint_auth_method"] == "client_secret_post" + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + # Try to use Basic auth when client_secret_post is registered (without secret in body) + # This should fail because the secret is missing from the expected location + + credentials = f"{client_info['client_id']}:{client_info['client_secret']}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {encoded_credentials}"}, + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + # client_secret NOT in body where it should be + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 401 + error_response = response.json() + # RFC 6749: authentication failures return "invalid_client" + assert error_response["error"] == "invalid_client" + assert "Client secret is required" in error_response["error_description"] + + @pytest.mark.anyio + async def test_basic_auth_without_header_fails( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that omitting Basic auth when client_secret_basic is registered fails.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Basic Auth Client", + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + assert client_info["token_endpoint_auth_method"] == "client_secret_basic" + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], # Secret in body (ignored) + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 401 + error_response = response.json() + # RFC 6749: authentication failures return "invalid_client" + assert error_response["error"] == "invalid_client" + assert "Missing or invalid Basic authentication" in error_response["error_description"] + + @pytest.mark.anyio + async def test_basic_auth_invalid_base64_fails( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that invalid base64 in Basic auth header fails.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Basic Auth Client", + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + # Send invalid base64 + response = await test_client.post( + "/token", + headers={"Authorization": "Basic !!!invalid-base64!!!"}, + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 401 + error_response = response.json() + # RFC 6749: authentication failures return "invalid_client" + assert error_response["error"] == "invalid_client" + assert "Invalid Basic authentication header" in error_response["error_description"] + + @pytest.mark.anyio + async def test_basic_auth_no_colon_fails( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that Basic auth without colon separator fails.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Basic Auth Client", + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + # Send base64 without colon (invalid format) + import base64 + + invalid_creds = base64.b64encode(b"no-colon-here").decode() + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {invalid_creds}"}, + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 401 + error_response = response.json() + # RFC 6749: authentication failures return "invalid_client" + assert error_response["error"] == "invalid_client" + assert "Invalid Basic authentication header" in error_response["error_description"] + + @pytest.mark.anyio + async def test_basic_auth_client_id_mismatch_fails( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that client_id mismatch between body and Basic auth fails.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Basic Auth Client", + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + # Send different client_id in Basic auth header + import base64 + + wrong_creds = base64.b64encode(f"wrong-client-id:{client_info['client_secret']}".encode()).decode() + response = await test_client.post( + "/token", + headers={"Authorization": f"Basic {wrong_creds}"}, + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], # Correct client_id in body + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 401 + error_response = response.json() + # RFC 6749: authentication failures return "invalid_client" + assert error_response["error"] == "invalid_client" + assert "Client ID mismatch" in error_response["error_description"] + + @pytest.mark.anyio + async def test_none_auth_method_public_client( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str] + ): + """Test that 'none' authentication method works for public clients.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Public Client", + "token_endpoint_auth_method": "none", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201 + client_info = response.json() + assert client_info["token_endpoint_auth_method"] == "none" + # Public clients should not have a client_secret + assert "client_secret" not in client_info or client_info.get("client_secret") is None + + auth_code = f"code_{int(time.time())}" + mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode( + code=auth_code, + client_id=client_info["client_id"], + code_challenge=pkce_challenge["code_challenge"], + redirect_uri=AnyUrl("https://client.example.com/callback"), + redirect_uri_provided_explicitly=True, + scopes=["read", "write"], + expires_at=time.time() + 600, + ) + + # Token request without any client secret + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 200 + token_response = response.json() + assert "access_token" in token_response + class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" diff --git a/tests/server/fastmcp/prompts/test_base.py b/tests/server/fastmcp/prompts/test_base.py index 4e3a98aa8e..488bd5002c 100644 --- a/tests/server/fastmcp/prompts/test_base.py +++ b/tests/server/fastmcp/prompts/test_base.py @@ -36,7 +36,7 @@ async def fn(name: str, age: int = 30) -> str: @pytest.mark.anyio async def test_fn_with_invalid_kwargs(self): - async def fn(name: str, age: int = 30) -> str: + async def fn(name: str, age: int = 30) -> str: # pragma: no cover return f"Hello, {name}! You're {age} years old." prompt = Prompt.from_function(fn) diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py index 3239426f91..950ffddd1a 100644 --- a/tests/server/fastmcp/prompts/test_manager.py +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -8,7 +8,7 @@ class TestPromptManager: def test_add_prompt(self): """Test adding a prompt to the manager.""" - def fn() -> str: + def fn() -> str: # pragma: no cover return "Hello, world!" manager = PromptManager() @@ -20,7 +20,7 @@ def fn() -> str: def test_add_duplicate_prompt(self, caplog: pytest.LogCaptureFixture): """Test adding the same prompt twice.""" - def fn() -> str: + def fn() -> str: # pragma: no cover return "Hello, world!" manager = PromptManager() @@ -33,7 +33,7 @@ def fn() -> str: def test_disable_warn_on_duplicate_prompts(self, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate prompts.""" - def fn() -> str: + def fn() -> str: # pragma: no cover return "Hello, world!" manager = PromptManager(warn_on_duplicate_prompts=False) @@ -46,10 +46,10 @@ def fn() -> str: def test_list_prompts(self): """Test listing all prompts.""" - def fn1() -> str: + def fn1() -> str: # pragma: no cover return "Hello, world!" - def fn2() -> str: + def fn2() -> str: # pragma: no cover return "Goodbye, world!" manager = PromptManager() @@ -98,7 +98,7 @@ async def test_render_unknown_prompt(self): async def test_render_prompt_with_missing_args(self): """Test rendering a prompt with missing required arguments.""" - def fn(name: str) -> str: + def fn(name: str) -> str: # pragma: no cover return f"Hello, {name}!" manager = PromptManager() diff --git a/tests/server/fastmcp/resources/test_file_resources.py b/tests/server/fastmcp/resources/test_file_resources.py index ec3c85d8d0..c82cf85c5a 100644 --- a/tests/server/fastmcp/resources/test_file_resources.py +++ b/tests/server/fastmcp/resources/test_file_resources.py @@ -19,9 +19,9 @@ def temp_file(): f.write(content) path = Path(f.name).resolve() yield path - try: + try: # pragma: no cover path.unlink() - except FileNotFoundError: + except FileNotFoundError: # pragma: no cover pass # File was already deleted by the test @@ -102,7 +102,7 @@ async def test_missing_file_error(self, temp_file: Path): @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") @pytest.mark.anyio - async def test_permission_error(self, temp_file: Path): + async def test_permission_error(self, temp_file: Path): # pragma: no cover """Test reading a file without permissions.""" temp_file.chmod(0o000) # Remove all permissions try: diff --git a/tests/server/fastmcp/resources/test_function_resources.py b/tests/server/fastmcp/resources/test_function_resources.py index f30c6e7137..fccada4750 100644 --- a/tests/server/fastmcp/resources/test_function_resources.py +++ b/tests/server/fastmcp/resources/test_function_resources.py @@ -10,7 +10,7 @@ class TestFunctionResource: def test_function_resource_creation(self): """Test creating a FunctionResource.""" - def my_func() -> str: + def my_func() -> str: # pragma: no cover return "test content" resource = FunctionResource( @@ -141,7 +141,7 @@ async def get_data() -> str: async def test_from_function(self): """Test creating a FunctionResource from a function.""" - async def get_data() -> str: + async def get_data() -> str: # pragma: no cover """get_data returns a string""" return "Hello, world!" diff --git a/tests/server/fastmcp/resources/test_resource_manager.py b/tests/server/fastmcp/resources/test_resource_manager.py index bab0e9ad8b..a0c06be86c 100644 --- a/tests/server/fastmcp/resources/test_resource_manager.py +++ b/tests/server/fastmcp/resources/test_resource_manager.py @@ -18,9 +18,9 @@ def temp_file(): f.write(content) path = Path(f.name).resolve() yield path - try: + try: # pragma: no cover path.unlink() - except FileNotFoundError: + except FileNotFoundError: # pragma: no cover pass # File was already deleted by the test diff --git a/tests/server/fastmcp/resources/test_resource_template.py b/tests/server/fastmcp/resources/test_resource_template.py index f9b91a0a1f..c910f8fa85 100644 --- a/tests/server/fastmcp/resources/test_resource_template.py +++ b/tests/server/fastmcp/resources/test_resource_template.py @@ -4,7 +4,9 @@ import pytest from pydantic import BaseModel +from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.resources import FunctionResource, ResourceTemplate +from mcp.types import Annotations class TestResourceTemplate: @@ -13,7 +15,7 @@ class TestResourceTemplate: def test_template_creation(self): """Test creating a template from a function.""" - def my_func(key: str, value: int) -> dict[str, Any]: + def my_func(key: str, value: int) -> dict[str, Any]: # pragma: no cover return {"key": key, "value": value} template = ResourceTemplate.from_function( @@ -29,7 +31,7 @@ def my_func(key: str, value: int) -> dict[str, Any]: def test_template_matches(self): """Test matching URIs against a template.""" - def my_func(key: str, value: int) -> dict[str, Any]: + def my_func(key: str, value: int) -> dict[str, Any]: # pragma: no cover return {"key": key, "value": value} template = ResourceTemplate.from_function( @@ -186,3 +188,73 @@ def get_data(value: str) -> CustomData: assert isinstance(resource, FunctionResource) content = await resource.read() assert content == '"hello"' + + +class TestResourceTemplateAnnotations: + """Test annotations on resource templates.""" + + def test_template_with_annotations(self): + """Test creating a template with annotations.""" + + def get_user_data(user_id: str) -> str: # pragma: no cover + return f"User {user_id}" + + annotations = Annotations(priority=0.9) + + template = ResourceTemplate.from_function( + fn=get_user_data, uri_template="resource://users/{user_id}", annotations=annotations + ) + + assert template.annotations is not None + assert template.annotations.priority == 0.9 + + def test_template_without_annotations(self): + """Test that annotations are optional for templates.""" + + def get_user_data(user_id: str) -> str: # pragma: no cover + return f"User {user_id}" + + template = ResourceTemplate.from_function(fn=get_user_data, uri_template="resource://users/{user_id}") + + assert template.annotations is None + + @pytest.mark.anyio + async def test_template_annotations_in_fastmcp(self): + """Test template annotations via FastMCP decorator.""" + + mcp = FastMCP() + + @mcp.resource("resource://dynamic/{id}", annotations=Annotations(audience=["user"], priority=0.7)) + def get_dynamic(id: str) -> str: # pragma: no cover + """A dynamic annotated resource.""" + return f"Data for {id}" + + templates = await mcp.list_resource_templates() + assert len(templates) == 1 + assert templates[0].annotations is not None + assert templates[0].annotations.audience == ["user"] + assert templates[0].annotations.priority == 0.7 + + @pytest.mark.anyio + async def test_template_created_resources_inherit_annotations(self): + """Test that resources created from templates inherit annotations.""" + + def get_item(item_id: str) -> str: # pragma: no cover + return f"Item {item_id}" + + annotations = Annotations(priority=0.6) + + template = ResourceTemplate.from_function( + fn=get_item, uri_template="resource://items/{item_id}", annotations=annotations + ) + + # Create a resource from the template + resource = await template.create_resource("resource://items/123", {"item_id": "123"}) + + # The resource should inherit the template's annotations + assert resource.annotations is not None + assert resource.annotations.priority == 0.6 + + # Verify the resource works correctly + content = await resource.read() + assert content == "Item 123" diff --git a/tests/server/fastmcp/resources/test_resources.py b/tests/server/fastmcp/resources/test_resources.py index 08b3e65e12..32fc23b174 100644 --- a/tests/server/fastmcp/resources/test_resources.py +++ b/tests/server/fastmcp/resources/test_resources.py @@ -1,7 +1,9 @@ import pytest from pydantic import AnyUrl +from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.resources import FunctionResource, Resource +from mcp.types import Annotations class TestResourceValidation: @@ -10,7 +12,7 @@ class TestResourceValidation: def test_resource_uri_validation(self): """Test URI validation.""" - def dummy_func() -> str: + def dummy_func() -> str: # pragma: no cover return "data" # Valid URI @@ -40,7 +42,7 @@ def dummy_func() -> str: def test_resource_name_from_uri(self): """Test name is extracted from URI if not provided.""" - def dummy_func() -> str: + def dummy_func() -> str: # pragma: no cover return "data" resource = FunctionResource( @@ -52,7 +54,7 @@ def dummy_func() -> str: def test_resource_name_validation(self): """Test name validation.""" - def dummy_func() -> str: + def dummy_func() -> str: # pragma: no cover return "data" # Must provide either name or URI @@ -72,7 +74,7 @@ def dummy_func() -> str: def test_resource_mime_type(self): """Test mime type handling.""" - def dummy_func() -> str: + def dummy_func() -> str: # pragma: no cover return "data" # Default mime type @@ -99,3 +101,95 @@ class ConcreteResource(Resource): with pytest.raises(TypeError, match="abstract method"): ConcreteResource(uri=AnyUrl("test://test"), name="test") # type: ignore + + +class TestResourceAnnotations: + """Test annotations on resources.""" + + def test_resource_with_annotations(self): + """Test creating a resource with annotations.""" + + def get_data() -> str: # pragma: no cover + return "data" + + annotations = Annotations(audience=["user"], priority=0.8) + + resource = FunctionResource.from_function(fn=get_data, uri="resource://test", annotations=annotations) + + assert resource.annotations is not None + assert resource.annotations.audience == ["user"] + assert resource.annotations.priority == 0.8 + + def test_resource_without_annotations(self): + """Test that annotations are optional.""" + + def get_data() -> str: # pragma: no cover + return "data" + + resource = FunctionResource.from_function(fn=get_data, uri="resource://test") + + assert resource.annotations is None + + @pytest.mark.anyio + async def test_resource_annotations_in_fastmcp(self): + """Test resource annotations via FastMCP decorator.""" + + mcp = FastMCP() + + @mcp.resource("resource://annotated", annotations=Annotations(audience=["assistant"], priority=0.5)) + def get_annotated() -> str: # pragma: no cover + """An annotated resource.""" + return "annotated data" + + resources = await mcp.list_resources() + assert len(resources) == 1 + assert resources[0].annotations is not None + assert resources[0].annotations.audience == ["assistant"] + assert resources[0].annotations.priority == 0.5 + + @pytest.mark.anyio + async def test_resource_annotations_with_both_audiences(self): + """Test resource with both user and assistant audience.""" + + mcp = FastMCP() + + @mcp.resource("resource://both", annotations=Annotations(audience=["user", "assistant"], priority=1.0)) + def get_both() -> str: # pragma: no cover + return "for everyone" + + resources = await mcp.list_resources() + assert resources[0].annotations is not None + assert resources[0].annotations.audience == ["user", "assistant"] + assert resources[0].annotations.priority == 1.0 + + +class TestAnnotationsValidation: + """Test validation of annotation values.""" + + def test_priority_validation(self): + """Test that priority is validated to be between 0.0 and 1.0.""" + + # Valid priorities + Annotations(priority=0.0) + Annotations(priority=0.5) + Annotations(priority=1.0) + + # Invalid priorities should raise validation error + with pytest.raises(Exception): # Pydantic validation error + Annotations(priority=-0.1) + + with pytest.raises(Exception): + Annotations(priority=1.1) + + def test_audience_validation(self): + """Test that audience only accepts valid roles.""" + + # Valid audiences + Annotations(audience=["user"]) + Annotations(audience=["assistant"]) + Annotations(audience=["user", "assistant"]) + Annotations(audience=[]) + + # Invalid roles should raise validation error + with pytest.raises(Exception): # Pydantic validation error + Annotations(audience=["invalid_role"]) # type: ignore diff --git a/tests/server/fastmcp/servers/test_file_server.py b/tests/server/fastmcp/servers/test_file_server.py index df70245523..b8c9ad3d6a 100644 --- a/tests/server/fastmcp/servers/test_file_server.py +++ b/tests/server/fastmcp/servers/test_file_server.py @@ -44,17 +44,17 @@ def read_example_py() -> str: @mcp.resource("file://test_dir/readme.md") def read_readme_md() -> str: """Read the readme.md file""" - try: + try: # pragma: no cover return (test_dir / "readme.md").read_text() - except FileNotFoundError: + except FileNotFoundError: # pragma: no cover return "File not found" @mcp.resource("file://test_dir/config.json") def read_config_json() -> str: """Read the config.json file""" - try: + try: # pragma: no cover return (test_dir / "config.json").read_text() - except FileNotFoundError: + except FileNotFoundError: # pragma: no cover return "File not found" return mcp @@ -65,7 +65,7 @@ def tools(mcp: FastMCP, test_dir: Path) -> FastMCP: @mcp.tool() def delete_file(path: str) -> bool: # ensure path is in test_dir - if Path(path).resolve().parent != test_dir: + if Path(path).resolve().parent != test_dir: # pragma: no cover raise ValueError(f"Path must be in test_dir: {path}") Path(path).unlink() return True diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 896eb1f80e..597b291785 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,6 +7,7 @@ import pytest from pydantic import BaseModel, Field +from mcp import types from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession @@ -31,7 +32,7 @@ async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: return f"User answered: {result.data.answer}" elif result.action == "decline": return "User declined to answer" - else: + else: # pragma: no cover return "User cancelled" return ask_user @@ -57,7 +58,7 @@ async def call_tool_and_assert( if expected_text is not None: assert result.content[0].text == expected_text - elif text_contains is not None: + elif text_contains is not None: # pragma: no branch for substring in text_contains: assert substring in result.content[0].text @@ -71,7 +72,9 @@ async def test_stdio_elicitation(): create_ask_user_tool(mcp) # Create a custom handler for elicitation requests - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback( + context: RequestContext[ClientSession, None], params: ElicitRequestParams + ): # pragma: no cover if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="http://github.com/accept", content={"answer": "Test User"}) else: @@ -103,7 +106,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: + async def tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" @@ -114,7 +117,7 @@ async def tool(ctx: Context[ServerSession, None]) -> str: # Test cases for invalid schemas class InvalidListSchema(BaseModel): - names: list[str] = Field(description="List of names") + numbers: list[int] = Field(description="List of numbers") class NestedModel(BaseModel): value: str @@ -126,7 +129,9 @@ class InvalidNestedSchema(BaseModel): create_validation_tool("nested_model", InvalidNestedSchema) # Dummy callback (won't be called due to validation failure) - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback( + context: RequestContext[ClientSession, None], params: ElicitRequestParams + ): # pragma: no cover return ElicitResult(action="http://github.com/accept", content={}) async with create_connected_server_and_client_session( @@ -135,7 +140,7 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par await client_session.initialize() # Test both invalid schemas - for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]: + for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]: result = await client_session.call_tool(tool_name, {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -166,7 +171,7 @@ async def optional_tool(ctx: Context[ServerSession, None]) -> str: info.append(f"Email: {result.data.optional_email}") info.append(f"Subscribe: {result.data.subscribe}") return ", ".join(info) - else: + else: # pragma: no cover return f"User {result.action}" # Test cases with different field combinations @@ -193,17 +198,19 @@ async def callback(context: RequestContext[ClientSession, None], params: ElicitR # Test invalid optional field class InvalidOptionalSchema(BaseModel): name: str = Field(description="Name") - optional_list: list[str] | None = Field(default=None, description="Invalid optional list") + optional_list: list[int] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: + async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" except TypeError as e: return f"Validation failed: {str(e)}" - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback( + context: RequestContext[ClientSession, None], params: ElicitRequestParams + ): # pragma: no cover return ElicitResult(action="http://github.com/accept", content={}) await call_tool_and_assert( @@ -214,6 +221,47 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par text_contains=["Validation failed:", "optional_list"], ) + # Test valid list[str] for multi-select enum + class ValidMultiSelectSchema(BaseModel): + name: str = Field(description="Name") + tags: list[str] = Field(description="Tags") + + @mcp.tool(description="Tool with valid list[str] field") + async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema) + if result.action == "accept" and result.data: + return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" + return f"User {result.action}" # pragma: no cover + + async def multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams): + if "Please provide tags" in params.message: + return ElicitResult(action="http://github.com/accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) + return ElicitResult(action="http://github.com/decline") # pragma: no cover + + await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2") + + # Test Optional[list[str]] for optional multi-select enum + class OptionalMultiSelectSchema(BaseModel): + name: str = Field(description="Name") + tags: list[str] | None = Field(default=None, description="Optional tags") + + @mcp.tool(description="Tool with optional list[str] field") + async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="Please provide optional tags", schema=OptionalMultiSelectSchema) + if result.action == "accept" and result.data: + tags_str = ", ".join(result.data.tags) if result.data.tags else "none" + return f"Name: {result.data.name}, Tags: {tags_str}" + return f"User {result.action}" # pragma: no cover + + async def optional_multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams): + if "Please provide optional tags" in params.message: + return ElicitResult(action="http://github.com/accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) + return ElicitResult(action="http://github.com/decline") # pragma: no cover + + await call_tool_and_assert( + mcp, optional_multiselect_callback, "optional_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2" + ) + @pytest.mark.anyio async def test_elicitation_with_default_values(): @@ -235,12 +283,13 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: f"Name: {result.data.name}, Age: {result.data.age}, " f"Subscribe: {result.data.subscribe}, Email: {result.data.email}" ) - else: + else: # pragma: no cover return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): # Verify the schema includes defaults + assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requestedSchema props = schema["properties"] @@ -268,3 +317,89 @@ async def callback_override(context: RequestContext[ClientSession, None], params await call_tool_and_assert( mcp, callback_override, "defaults_tool", {}, "Name: John, Age: 25, Subscribe: False, Email: john@example.com" ) + + +@pytest.mark.anyio +async def test_elicitation_with_enum_titles(): + """Test elicitation with enum schemas using oneOf/anyOf for titles.""" + mcp = FastMCP(name="ColorPreferencesApp") + + # Test single-select with titles using oneOf + class FavoriteColorSchema(BaseModel): + user_name: str = Field(description="Your name") + favorite_color: str = Field( + description="Select your favorite color", + json_schema_extra={ + "oneOf": [ + {"const": "red", "title": "Red"}, + {"const": "green", "title": "Green"}, + {"const": "blue", "title": "Blue"}, + {"const": "yellow", "title": "Yellow"}, + ] + }, + ) + + @mcp.tool(description="Single color selection") + async def select_favorite_color(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema) + if result.action == "accept" and result.data: + return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}" + return f"User {result.action}" # pragma: no cover + + # Test multi-select with titles using anyOf + class FavoriteColorsSchema(BaseModel): + user_name: str = Field(description="Your name") + favorite_colors: list[str] = Field( + description="Select your favorite colors", + json_schema_extra={ + "items": { + "anyOf": [ + {"const": "red", "title": "Red"}, + {"const": "green", "title": "Green"}, + {"const": "blue", "title": "Blue"}, + {"const": "yellow", "title": "Yellow"}, + ] + } + }, + ) + + @mcp.tool(description="Multiple color selection") + async def select_favorite_colors(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema) + if result.action == "accept" and result.data: + return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}" + return f"User {result.action}" # pragma: no cover + + # Test legacy enumNames format + class LegacyColorSchema(BaseModel): + user_name: str = Field(description="Your name") + color: str = Field( + description="Select a color", + json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]}, + ) + + @mcp.tool(description="Legacy enum format") + async def select_color_legacy(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="Select a color (legacy format)", schema=LegacyColorSchema) + if result.action == "accept" and result.data: + return f"User: {result.data.user_name}, Color: {result.data.color}" + return f"User {result.action}" # pragma: no cover + + async def enum_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams): + if "colors" in params.message and "legacy" not in params.message: + return ElicitResult(action="http://github.com/accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]}) + elif "color" in params.message: + if "legacy" in params.message: + return ElicitResult(action="http://github.com/accept", content={"user_name": "Charlie", "color": "green"}) + else: + return ElicitResult(action="http://github.com/accept", content={"user_name": "Alice", "favorite_color": "blue"}) + return ElicitResult(action="http://github.com/decline") # pragma: no cover + + # Test single-select with titles + await call_tool_and_assert(mcp, enum_callback, "select_favorite_color", {}, "User: Alice, Favorite: blue") + + # Test multi-select with titles + await call_tool_and_assert(mcp, enum_callback, "select_favorite_colors", {}, "User: Bob, Colors: red, green") + + # Test legacy enumNames format + await call_tool_and_assert(mcp, enum_callback, "select_color_legacy", {}, "User: Charlie, Color: green") diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index 830cf816b0..61e524290e 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -5,7 +5,7 @@ # pyright: reportUnknownLambdaType=false from collections.abc import Callable from dataclasses import dataclass -from typing import Annotated, Any, TypedDict +from typing import Annotated, Any, Final, TypedDict import annotated_types import pytest @@ -13,6 +13,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp.utilities.func_metadata import func_metadata +from mcp.types import CallToolResult class SomeInputModelA(BaseModel): @@ -160,7 +161,7 @@ def test_str_vs_list_str(): We want to make sure it's kept as a python string. """ - def func_with_str_types(str_or_list: str | list[str]): + def func_with_str_types(str_or_list: str | list[str]): # pragma: no cover return str_or_list meta = func_metadata(func_with_str_types) @@ -183,7 +184,7 @@ def func_with_str_types(str_or_list: str | list[str]): def test_skip_names(): """Test that skipped parameters are not included in the model""" - def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool): + def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool): # pragma: no cover return keep_this, skip_this, also_keep, also_skip # Skip some parameters @@ -205,7 +206,7 @@ def test_structured_output_dict_str_types(): """Test that dict[str, T] types are handled without wrapping.""" # Test dict[str, Any] - def func_dict_any() -> dict[str, Any]: + def func_dict_any() -> dict[str, Any]: # pragma: no cover return {"a": 1, "b": "hello", "c": [1, 2, 3]} meta = func_metadata(func_dict_any) @@ -213,7 +214,7 @@ def func_dict_any() -> dict[str, Any]: assert meta.output_schema == IsPartialDict(type="object", title="func_dict_anyDictOutput") # Test dict[str, str] - def func_dict_str() -> dict[str, str]: + def func_dict_str() -> dict[str, str]: # pragma: no cover return {"name": "John", "city": "NYC"} meta = func_metadata(func_dict_str) @@ -224,7 +225,7 @@ def func_dict_str() -> dict[str, str]: } # Test dict[str, list[int]] - def func_dict_list() -> dict[str, list[int]]: + def func_dict_list() -> dict[str, list[int]]: # pragma: no cover return {"nums": [1, 2, 3], "more": [4, 5, 6]} meta = func_metadata(func_dict_list) @@ -235,7 +236,7 @@ def func_dict_list() -> dict[str, list[int]]: } # Test dict[int, str] - should be wrapped since key is not str - def func_dict_int_key() -> dict[int, str]: + def func_dict_int_key() -> dict[int, str]: # pragma: no cover return {1: "a", 2: "b"} meta = func_metadata(func_dict_int_key) @@ -311,8 +312,8 @@ def test_complex_function_json_schema(): normalized_schema = actual_schema.copy() # Normalize the my_model_a_with_default field to handle both pydantic formats - if "allOf" in actual_schema["properties"]["my_model_a_with_default"]: - normalized_schema["properties"]["my_model_a_with_default"] = { + if "allOf" in actual_schema["properties"]["my_model_a_with_default"]: # pragma: no cover + normalized_schema["properties"]["my_model_a_with_default"] = { # pragma: no cover "$ref": "#/$defs/SomeInputModelA", "default": {}, } @@ -451,7 +452,7 @@ def test_str_vs_int(): while numbers are parsed correctly. """ - def func_with_str_and_int(a: str, b: int): + def func_with_str_and_int(a: str, b: int): # pragma: no cover return a meta = func_metadata(func_with_str_and_int) @@ -469,7 +470,7 @@ def test_str_annotation_preserves_json_string(): and passes after the fix (JSON string remains as string). """ - def process_json_config(config: str, enabled: bool = True) -> str: + def process_json_config(config: str, enabled: bool = True) -> str: # pragma: no cover """Function that expects a JSON string as a string parameter.""" # In real use, this function might validate or transform the JSON string # before parsing it, or pass it to another service as-is @@ -517,7 +518,7 @@ async def test_str_annotation_runtime_validation(): containing valid JSON to ensure they are passed as strings, not parsed objects. """ - def handle_json_payload(payload: str, strict_mode: bool = False) -> str: + def handle_json_payload(payload: str, strict_mode: bool = False) -> str: # pragma: no cover """Function that processes a JSON payload as a string.""" # This function expects to receive the raw JSON string # It might parse it later after validation or logging @@ -559,10 +560,10 @@ def test_structured_output_requires_return_annotation(): """Test that structured_output=True requires a return annotation""" from mcp.server.fastmcp.exceptions import InvalidSignature - def func_no_annotation(): + def func_no_annotation(): # pragma: no cover return "hello" - def func_none_annotation() -> None: + def func_none_annotation() -> None: # pragma: no cover return None with pytest.raises(InvalidSignature) as exc_info: @@ -587,7 +588,7 @@ class PersonModel(BaseModel): age: int email: str | None = None - def func_returning_person() -> PersonModel: + def func_returning_person() -> PersonModel: # pragma: no cover return PersonModel(name="Alice", age=30) meta = func_metadata(func_returning_person) @@ -606,19 +607,19 @@ def func_returning_person() -> PersonModel: def test_structured_output_primitives(): """Test structured output with primitive return types""" - def func_str() -> str: + def func_str() -> str: # pragma: no cover return "hello" - def func_int() -> int: + def func_int() -> int: # pragma: no cover return 42 - def func_float() -> float: + def func_float() -> float: # pragma: no cover return 3.14 - def func_bool() -> bool: + def func_bool() -> bool: # pragma: no cover return True - def func_bytes() -> bytes: + def func_bytes() -> bytes: # pragma: no cover return b"data" # Test string @@ -670,16 +671,16 @@ def func_bytes() -> bytes: def test_structured_output_generic_types(): """Test structured output with generic types (list, dict, Union, etc.)""" - def func_list_str() -> list[str]: + def func_list_str() -> list[str]: # pragma: no cover return ["a", "b", "c"] - def func_dict_str_int() -> dict[str, int]: + def func_dict_str_int() -> dict[str, int]: # pragma: no cover return {"a": 1, "b": 2} - def func_union() -> str | int: + def func_union() -> str | int: # pragma: no cover return "hello" - def func_optional() -> str | None: + def func_optional() -> str | None: # pragma: no cover return None # Test list @@ -728,7 +729,7 @@ class PersonDataClass: email: str | None = None tags: list[str] | None = None - def func_returning_dataclass() -> PersonDataClass: + def func_returning_dataclass() -> PersonDataClass: # pragma: no cover return PersonDataClass(name="Bob", age=25) meta = func_metadata(func_returning_dataclass) @@ -756,7 +757,7 @@ class PersonTypedDictOptional(TypedDict, total=False): name: str age: int - def func_returning_typeddict_optional() -> PersonTypedDictOptional: + def func_returning_typeddict_optional() -> PersonTypedDictOptional: # pragma: no cover return {"name": "Dave"} # Only returning one field to test partial dict meta = func_metadata(func_returning_typeddict_optional) @@ -775,7 +776,7 @@ class PersonTypedDictRequired(TypedDict): age: int email: str | None - def func_returning_typeddict_required() -> PersonTypedDictRequired: + def func_returning_typeddict_required() -> PersonTypedDictRequired: # pragma: no cover return {"name": "Eve", "age": 40, "email": None} # Testing None value meta = func_metadata(func_returning_typeddict_required) @@ -799,12 +800,12 @@ class PersonClass: age: int email: str | None - def __init__(self, name: str, age: int, email: str | None = None): + def __init__(self, name: str, age: int, email: str | None = None): # pragma: no cover self.name = name self.age = age self.email = email - def func_returning_class() -> PersonClass: + def func_returning_class() -> PersonClass: # pragma: no cover return PersonClass("Helen", 55) meta = func_metadata(func_returning_class) @@ -823,17 +824,104 @@ def func_returning_class() -> PersonClass: def test_unstructured_output_unannotated_class(): # Test with class that has no annotations class UnannotatedClass: - def __init__(self, x, y): + def __init__(self, x, y): # pragma: no cover self.x = x self.y = y - def func_returning_unannotated() -> UnannotatedClass: + def func_returning_unannotated() -> UnannotatedClass: # pragma: no cover return UnannotatedClass(1, 2) meta = func_metadata(func_returning_unannotated) assert meta.output_schema is None +def test_tool_call_result_is_unstructured_and_not_converted(): + def func_returning_call_tool_result() -> CallToolResult: # pragma: no cover + return CallToolResult(content=[]) + + meta = func_metadata(func_returning_call_tool_result) + + assert meta.output_schema is None + assert isinstance(meta.convert_result(func_returning_call_tool_result()), CallToolResult) + + +def test_tool_call_result_annotated_is_structured_and_converted(): + class PersonClass(BaseModel): + name: str + + def func_returning_annotated_tool_call_result() -> Annotated[CallToolResult, PersonClass]: # pragma: no cover + return CallToolResult(content=[], structuredContent={"name": "Brandon"}) + + meta = func_metadata(func_returning_annotated_tool_call_result) + + assert meta.output_schema == { + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"}, + }, + "required": ["name"], + "title": "PersonClass", + } + assert isinstance(meta.convert_result(func_returning_annotated_tool_call_result()), CallToolResult) + + +def test_tool_call_result_annotated_is_structured_and_invalid(): + class PersonClass(BaseModel): + name: str + + def func_returning_annotated_tool_call_result() -> Annotated[CallToolResult, PersonClass]: # pragma: no cover + return CallToolResult(content=[], structuredContent={"person": "Brandon"}) + + meta = func_metadata(func_returning_annotated_tool_call_result) + + with pytest.raises(ValueError): + meta.convert_result(func_returning_annotated_tool_call_result()) + + +def test_tool_call_result_in_optional_is_rejected(): + """Test that Optional[CallToolResult] raises InvalidSignature""" + + from mcp.server.fastmcp.exceptions import InvalidSignature + + def func_optional_call_tool_result() -> CallToolResult | None: # pragma: no cover + return CallToolResult(content=[]) + + with pytest.raises(InvalidSignature) as exc_info: + func_metadata(func_optional_call_tool_result) + + assert "Union or Optional" in str(exc_info.value) + assert "CallToolResult" in str(exc_info.value) + + +def test_tool_call_result_in_union_is_rejected(): + """Test that Union[str, CallToolResult] raises InvalidSignature""" + + from mcp.server.fastmcp.exceptions import InvalidSignature + + def func_union_call_tool_result() -> str | CallToolResult: # pragma: no cover + return CallToolResult(content=[]) + + with pytest.raises(InvalidSignature) as exc_info: + func_metadata(func_union_call_tool_result) + + assert "Union or Optional" in str(exc_info.value) + assert "CallToolResult" in str(exc_info.value) + + +def test_tool_call_result_in_pipe_union_is_rejected(): + """Test that str | CallToolResult raises InvalidSignature""" + from mcp.server.fastmcp.exceptions import InvalidSignature + + def func_pipe_union_call_tool_result() -> str | CallToolResult: # pragma: no cover + return CallToolResult(content=[]) + + with pytest.raises(InvalidSignature) as exc_info: + func_metadata(func_pipe_union_call_tool_result) + + assert "Union or Optional" in str(exc_info.value) + assert "CallToolResult" in str(exc_info.value) + + def test_structured_output_with_field_descriptions(): """Test that Field descriptions are preserved in structured output""" @@ -841,7 +929,7 @@ class ModelWithDescriptions(BaseModel): name: Annotated[str, Field(description="The person's full name")] age: Annotated[int, Field(description="Age in years", ge=0, le=150)] - def func_with_descriptions() -> ModelWithDescriptions: + def func_with_descriptions() -> ModelWithDescriptions: # pragma: no cover return ModelWithDescriptions(name="Ian", age=60) meta = func_metadata(func_with_descriptions) @@ -868,7 +956,7 @@ class PersonWithAddress(BaseModel): name: str address: Address - def func_nested() -> PersonWithAddress: + def func_nested() -> PersonWithAddress: # pragma: no cover return PersonWithAddress(name="Jack", address=Address(street="123 Main St", city="Anytown", zipcode="12345")) meta = func_metadata(func_nested) @@ -907,7 +995,7 @@ class ConfigWithCallable: # Callable defaults are not JSON serializable and will trigger Pydantic warnings callback: Callable[[Any], Any] = lambda x: x * 2 - def func_returning_config_with_callable() -> ConfigWithCallable: + def func_returning_config_with_callable() -> ConfigWithCallable: # pragma: no cover return ConfigWithCallable() # Should work without structured_output=True (returns None for output_schema) @@ -925,7 +1013,7 @@ class Point(NamedTuple): x: int y: int - def func_returning_namedtuple() -> Point: + def func_returning_namedtuple() -> Point: # pragma: no cover return Point(1, 2) # Should work without structured_output=True (returns None for output_schema) @@ -946,7 +1034,7 @@ class ModelWithAliases(BaseModel): field_first: str | None = Field(default=None, alias="first", description="The first field.") field_second: str | None = Field(default=None, alias="second", description="The second field.") - def func_with_aliases() -> ModelWithAliases: + def func_with_aliases() -> ModelWithAliases: # pragma: no cover # When aliases are defined, we must use the aliased names to set values return ModelWithAliases(**{"first": "hello", "second": "world"}) @@ -987,7 +1075,7 @@ def func_with_aliases() -> ModelWithAliases: def test_basemodel_reserved_names(): """Test that functions with parameters named after BaseModel methods work correctly""" - def func_with_reserved_names( + def func_with_reserved_names( # pragma: no cover model_dump: str, model_validate: int, dict: list[str], @@ -1015,7 +1103,7 @@ def func_with_reserved_names( async def test_basemodel_reserved_names_validation(): """Test that validation and calling works with reserved parameter names""" - def func_with_reserved_names( + def func_with_reserved_names( # pragma: no cover model_dump: str, model_validate: int, dict: list[str], @@ -1073,7 +1161,7 @@ def func_with_reserved_names( def test_basemodel_reserved_names_with_json_preparsing(): """Test that pre_parse_json works correctly with reserved parameter names""" - def func_with_reserved_json( + def func_with_reserved_json( # pragma: no cover json: dict[str, Any], model_dump: list[int], normal: str, @@ -1094,3 +1182,23 @@ def func_with_reserved_json( assert result["json"] == {"nested": "data"} assert result["model_dump"] == [1, 2, 3] assert result["normal"] == "plain string" + + +def test_disallowed_type_qualifier(): + from mcp.server.fastmcp.exceptions import InvalidSignature + + def func_disallowed_qualifier() -> Final[int]: # type: ignore + pass # pragma: no cover + + with pytest.raises(InvalidSignature) as exc_info: + func_metadata(func_disallowed_qualifier) + assert "return annotation contains an invalid type qualifier" in str(exc_info.value) + + +def test_preserves_pydantic_metadata(): + def func_with_metadata() -> Annotated[int, Field(gt=1)]: ... # pragma: no branch + + meta = func_metadata(func_with_metadata) + + assert meta.output_schema is not None + assert meta.output_schema["properties"]["result"] == {"exclusiveMinimum": 1, "title": "Result", "type": "integer"} diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index dc88cc0256..70948bd7e2 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -13,7 +13,6 @@ import json import multiprocessing import socket -import time from collections.abc import Generator import pytest @@ -35,7 +34,7 @@ ) from mcp.client.session import ClientSession from mcp.client.sse import sse_client -from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.streamable_http import GetSessionIdCallback, streamable_http_client from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -60,6 +59,7 @@ TextResourceContents, ToolListChangedNotification, ) +from tests.test_helpers import wait_for_server class NotificationCollector: @@ -75,14 +75,14 @@ async def handle_generic_notification( self, message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception ) -> None: """Handle any server notification and route to appropriate handler.""" - if isinstance(message, ServerNotification): + if isinstance(message, ServerNotification): # pragma: no branch if isinstance(message.root, ProgressNotification): self.progress_notifications.append(message.root.params) elif isinstance(message.root, LoggingMessageNotification): self.log_messages.append(message.root.params) elif isinstance(message.root, ResourceListChangedNotification): self.resource_notifications.append(message.root.params) - elif isinstance(message.root, ToolListChangedNotification): + elif isinstance(message.root, ToolListChangedNotification): # pragma: no cover self.tool_notifications.append(message.root.params) @@ -101,7 +101,7 @@ def server_url(http://github.com/server_port:%20int) -> str: return f"http://127.0.0.1:{server_port}" -def run_server_with_transport(module_name: str, port: int, transport: str) -> None: +def run_server_with_transport(module_name: str, port: int, transport: str) -> None: # pragma: no cover """Run server with specified transport.""" # Get the MCP instance based on module name if module_name == "basic_tool": @@ -160,25 +160,14 @@ def server_transport(request: pytest.FixtureRequest, server_port: int) -> Genera ) proc.start() - # Wait for server to be running - max_attempts = 20 - attempt = 0 - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + # Wait for server to be ready + wait_for_server(server_port) yield transport proc.kill() proc.join(timeout=2) - if proc.is_alive(): + if proc.is_alive(): # pragma: no cover print("Server process failed to terminate") @@ -190,8 +179,8 @@ def create_client_for_transport(transport: str, server_url: str): return sse_client(endpoint) elif transport == "streamable-http": endpoint = f"{server_url}/mcp" - return streamablehttp_client(endpoint) - else: + return streamable_http_client(endpoint) + else: # pragma: no cover raise ValueError(f"Invalid transport: {transport}") @@ -244,7 +233,7 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par action="http://github.com/accept", content={"checkAlternative": True, "alternativeDate": "2024-12-26"}, ) - else: + else: # pragma: no cover return ElicitResult(action="http://github.com/decline") @@ -396,7 +385,7 @@ async def test_tool_progress(server_transport: str, server_url: str) -> None: async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message client_cm = create_client_for_transport(transport, server_url) @@ -537,7 +526,7 @@ async def test_notifications(server_transport: str, server_url: str) -> None: async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message client_cm = create_client_for_transport(transport, server_url) diff --git a/tests/server/fastmcp/test_parameter_descriptions.py b/tests/server/fastmcp/test_parameter_descriptions.py index 29470ed19c..9f2386894c 100644 --- a/tests/server/fastmcp/test_parameter_descriptions.py +++ b/tests/server/fastmcp/test_parameter_descriptions.py @@ -14,7 +14,7 @@ async def test_parameter_descriptions(): def greet( name: str = Field(description="The name to greet"), title: str = Field(description="Optional title", default=""), - ) -> str: + ) -> str: # pragma: no cover """A greeting tool""" return f"Hello {title} {name}" diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 8caa3b1f6f..3935f3bd13 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -12,6 +12,7 @@ from mcp.server.fastmcp.resources import FileResource, FunctionResource from mcp.server.fastmcp.utilities.types import Audio, Image from mcp.server.session import ServerSession +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import McpError from mcp.shared.memory import ( create_connected_server_and_client_session as client_session, @@ -147,7 +148,7 @@ async def test_add_tool_decorator(self): mcp = FastMCP() @mcp.tool() - def sum(x: int, y: int) -> int: + def sum(x: int, y: int) -> int: # pragma: no cover return x + y assert len(mcp._tool_manager.list_tools()) == 1 @@ -159,7 +160,7 @@ async def test_add_tool_decorator_incorrect_usage(self): with pytest.raises(TypeError, match="The @tool decorator was used incorrectly"): @mcp.tool # Missing parentheses #type: ignore - def sum(x: int, y: int) -> int: + def sum(x: int, y: int) -> int: # pragma: no cover return x + y @pytest.mark.anyio @@ -167,7 +168,7 @@ async def test_add_resource_decorator(self): mcp = FastMCP() @mcp.resource("r://{x}") - def get_data(x: str) -> str: + def get_data(x: str) -> str: # pragma: no cover return f"Data: {x}" assert len(mcp._resource_manager._templates) == 1 @@ -179,10 +180,56 @@ async def test_add_resource_decorator_incorrect_usage(self): with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"): @mcp.resource # Missing parentheses #type: ignore - def get_data(x: str) -> str: + def get_data(x: str) -> str: # pragma: no cover return f"Data: {x}" +class TestDnsRebindingProtection: + """Tests for automatic DNS rebinding protection on localhost.""" + + def test_auto_enabled_for_127_0_0_1(self): + """DNS rebinding protection should auto-enable for host=127.0.0.1.""" + mcp = FastMCP(host="127.0.0.1") + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + assert "127.0.0.1:*" in mcp.settings.transport_security.allowed_hosts + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + assert "http://127.0.0.1:*" in mcp.settings.transport_security.allowed_origins + assert "http://localhost:*" in mcp.settings.transport_security.allowed_origins + + def test_auto_enabled_for_localhost(self): + """DNS rebinding protection should auto-enable for host=localhost.""" + mcp = FastMCP(host="localhost") + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + assert "127.0.0.1:*" in mcp.settings.transport_security.allowed_hosts + assert "localhost:*" in mcp.settings.transport_security.allowed_hosts + + def test_auto_enabled_for_ipv6_localhost(self): + """DNS rebinding protection should auto-enable for host=::1 (IPv6 localhost).""" + mcp = FastMCP(host="::1") + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is True + assert "[::1]:*" in mcp.settings.transport_security.allowed_hosts + assert "http://[::1]:*" in mcp.settings.transport_security.allowed_origins + + def test_not_auto_enabled_for_other_hosts(self): + """DNS rebinding protection should NOT auto-enable for other hosts.""" + mcp = FastMCP(host="0.0.0.0") + assert mcp.settings.transport_security is None + + def test_explicit_settings_not_overridden(self): + """Explicit transport_security settings should not be overridden.""" + custom_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ) + mcp = FastMCP(host="127.0.0.1", transport_security=custom_settings) + # Settings are copied by pydantic, so check values not identity + assert mcp.settings.transport_security is not None + assert mcp.settings.transport_security.enable_dns_rebinding_protection is False + assert mcp.settings.transport_security.allowed_hosts == [] + + def tool_fn(x: int, y: int) -> int: return x + y @@ -756,7 +803,7 @@ async def test_function_resource(self): mcp = FastMCP() @mcp.resource("function://test", name="test_get_data") - def get_data() -> str: + def get_data() -> str: # pragma: no cover """get_data returns a string""" return "Hello, world!" @@ -780,7 +827,7 @@ async def test_resource_with_params(self): with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://data") - def get_data_fn(param: str) -> str: + def get_data_fn(param: str) -> str: # pragma: no cover return f"Data: {param}" @pytest.mark.anyio @@ -791,7 +838,7 @@ async def test_resource_with_uri_params(self): with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://{param}") - def get_data() -> str: + def get_data() -> str: # pragma: no cover return "Data" @pytest.mark.anyio @@ -800,7 +847,7 @@ async def test_resource_with_untyped_params(self): mcp = FastMCP() @mcp.resource("resource://{param}") - def get_data(param) -> str: # type: ignore + def get_data(param) -> str: # type: ignore # pragma: no cover return "Data" @pytest.mark.anyio @@ -825,7 +872,7 @@ async def test_resource_mismatched_params(self): with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://{name}/data") - def get_data(user: str) -> str: + def get_data(user: str) -> str: # pragma: no cover return f"Data for {user}" @pytest.mark.anyio @@ -850,10 +897,10 @@ async def test_resource_multiple_mismatched_params(self): with pytest.raises(ValueError, match="Mismatch between URI parameters"): @mcp.resource("resource://{org}/{repo}/data") - def get_data_mismatched(org: str, repo_2: str) -> str: + def get_data_mismatched(org: str, repo_2: str) -> str: # pragma: no cover return f"Data for {org}" - """Test that a resource with no parameters works as a regular resource""" + """Test that a resource with no parameters works as a regular resource""" # pragma: no cover mcp = FastMCP() @mcp.resource("resource://static") @@ -914,7 +961,7 @@ async def test_context_detection(self): """Test that context parameters are properly detected.""" mcp = FastMCP() - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: + def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: # pragma: no cover return f"Request {ctx.request_id}: {x}" tool = mcp._tool_manager.add_tool(tool_with_context) @@ -1223,7 +1270,7 @@ def test_prompt_decorator_error(self): with pytest.raises(TypeError, match="decorator was used incorrectly"): @mcp.prompt # type: ignore - def fn() -> str: + def fn() -> str: # pragma: no cover return "Hello, world!" @pytest.mark.anyio @@ -1232,7 +1279,7 @@ async def test_list_prompts(self): mcp = FastMCP() @mcp.prompt() - def fn(name: str, optional: str = "default") -> str: + def fn(name: str, optional: str = "default") -> str: # pragma: no cover return f"Hello, {name}!" async with client_session(mcp._mcp_server) as client: @@ -1350,7 +1397,7 @@ async def test_get_prompt_missing_args(self): mcp = FastMCP() @mcp.prompt() - def prompt_fn(name: str) -> str: + def prompt_fn(name: str) -> str: # pragma: no cover return f"Hello, {name}!" async with client_session(mcp._mcp_server) as client: diff --git a/tests/server/fastmcp/test_title.py b/tests/server/fastmcp/test_title.py index a94f6671db..7cac570123 100644 --- a/tests/server/fastmcp/test_title.py +++ b/tests/server/fastmcp/test_title.py @@ -18,23 +18,23 @@ async def test_tool_title_precedence(): # Tool with only name @mcp.tool(description="Basic tool") - def basic_tool(message: str) -> str: + def basic_tool(message: str) -> str: # pragma: no cover return message # Tool with title @mcp.tool(description="Tool with title", title="User-Friendly Tool") - def tool_with_title(message: str) -> str: + def tool_with_title(message: str) -> str: # pragma: no cover return message # Tool with annotations.title (when title is not supported on decorator) # We'll need to add this manually after registration @mcp.tool(description="Tool with annotations") - def tool_with_annotations(message: str) -> str: + def tool_with_annotations(message: str) -> str: # pragma: no cover return message # Tool with both title and annotations.title @mcp.tool(description="Tool with both", title="Primary Title") - def tool_with_both(message: str) -> str: + def tool_with_both(message: str) -> str: # pragma: no cover return message # Start server and connect client @@ -73,12 +73,12 @@ async def test_prompt_title(): # Prompt with only name @mcp.prompt(description="Basic prompt") - def basic_prompt(topic: str) -> str: + def basic_prompt(topic: str) -> str: # pragma: no cover return f"Tell me about {topic}" # Prompt with title @mcp.prompt(description="Titled prompt", title="Ask About Topic") - def titled_prompt(topic: str) -> str: + def titled_prompt(topic: str) -> str: # pragma: no cover return f"Tell me about {topic}" # Start server and connect client @@ -107,7 +107,7 @@ async def test_resource_title(): mcp = FastMCP(name="ResourceTitleServer") # Static resource without title - def get_basic_data() -> str: + def get_basic_data() -> str: # pragma: no cover return "Basic data" basic_resource = FunctionResource( @@ -119,7 +119,7 @@ def get_basic_data() -> str: mcp.add_resource(basic_resource) # Static resource with title - def get_titled_data() -> str: + def get_titled_data() -> str: # pragma: no cover return "Titled data" titled_resource = FunctionResource( @@ -133,12 +133,12 @@ def get_titled_data() -> str: # Dynamic resource without title @mcp.resource("resource://dynamic/{id}") - def dynamic_resource(id: str) -> str: + def dynamic_resource(id: str) -> str: # pragma: no cover return f"Data for {id}" # Dynamic resource with title (when supported) @mcp.resource("resource://titled-dynamic/{id}", title="Dynamic Data") - def titled_dynamic_resource(id: str) -> str: + def titled_dynamic_resource(id: str) -> str: # pragma: no cover return f"Data for {id}" # Start server and connect client @@ -171,7 +171,7 @@ def titled_dynamic_resource(id: str) -> str: assert dynamic.name == "dynamic_resource" # Verify titled dynamic resource template (when supported) - if "resource://titled-dynamic/{id}" in templates: + if "resource://titled-dynamic/{id}" in templates: # pragma: no branch titled_dynamic = templates["resource://titled-dynamic/{id}"] assert titled_dynamic.title == "Dynamic Data" diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 71884fba22..d83d484744 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -19,7 +19,7 @@ class TestAddTools: def test_basic_function(self): """Test registering and running a basic function.""" - def sum(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: # pragma: no cover """Add two numbers.""" return a + b @@ -35,7 +35,7 @@ def sum(a: int, b: int) -> int: assert tool.parameters["properties"]["b"]["type"] == "integer" def test_init_with_tools(self, caplog: pytest.LogCaptureFixture): - def sum(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: # pragma: no cover return a + b class AddArguments(ArgModelBase): @@ -68,7 +68,7 @@ class AddArguments(ArgModelBase): async def test_async_function(self): """Test registering and running an async function.""" - async def fetch_data(url: str) -> str: + async def fetch_data(url: str) -> str: # pragma: no cover """Fetch data from URL.""" return f"Data from {url}" @@ -89,7 +89,7 @@ class UserInput(BaseModel): name: str age: int - def create_user(user: UserInput, flag: bool) -> dict[str, Any]: + def create_user(user: UserInput, flag: bool) -> dict[str, Any]: # pragma: no cover """Create a new user.""" return {"id": 1, **user.model_dump()} @@ -112,7 +112,7 @@ class MyTool: def __init__(self): self.__name__ = "MyTool" - def __call__(self, x: int) -> int: + def __call__(self, x: int) -> int: # pragma: no cover return x * 2 manager = ToolManager() @@ -129,7 +129,7 @@ class MyAsyncTool: def __init__(self): self.__name__ = "MyAsyncTool" - async def __call__(self, x: int) -> int: + async def __call__(self, x: int) -> int: # pragma: no cover return x * 2 manager = ToolManager() @@ -156,7 +156,7 @@ def test_add_lambda_with_no_name(self): def test_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): """Test warning on duplicate tools.""" - def f(x: int) -> int: + def f(x: int) -> int: # pragma: no cover return x manager = ToolManager() @@ -168,7 +168,7 @@ def f(x: int) -> int: def test_disable_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate tools.""" - def f(x: int) -> int: + def f(x: int) -> int: # pragma: no cover return x manager = ToolManager() @@ -182,7 +182,7 @@ def f(x: int) -> int: class TestCallTools: @pytest.mark.anyio async def test_call_tool(self): - def sum(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: # pragma: no cover """Add two numbers.""" return a + b @@ -193,7 +193,7 @@ def sum(a: int, b: int) -> int: @pytest.mark.anyio async def test_call_async_tool(self): - async def double(n: int) -> int: + async def double(n: int) -> int: # pragma: no cover """Double a number.""" return n * 2 @@ -243,7 +243,7 @@ def sum(a: int, b: int = 1) -> int: @pytest.mark.anyio async def test_call_tool_with_missing_args(self): - def sum(a: int, b: int) -> int: + def sum(a: int, b: int) -> int: # pragma: no cover """Add two numbers.""" return a + b @@ -260,7 +260,7 @@ async def test_call_unknown_tool(self): @pytest.mark.anyio async def test_call_tool_with_list_int_input(self): - def sum_vals(vals: list[int]) -> int: + def sum_vals(vals: list[int]) -> int: # pragma: no cover return sum(vals) manager = ToolManager() @@ -273,7 +273,7 @@ def sum_vals(vals: list[int]) -> int: @pytest.mark.anyio async def test_call_tool_with_list_str_or_str_input(self): - def concat_strs(vals: list[str] | str) -> str: + def concat_strs(vals: list[str] | str) -> str: # pragma: no cover return vals if isinstance(vals, str) else "".join(vals) manager = ToolManager() @@ -297,7 +297,7 @@ class Shrimp(BaseModel): shrimp: list[Shrimp] x: None - def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[str]: + def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[str]: # pragma: no cover return [x.name for x in tank.shrimp] manager = ToolManager() @@ -317,7 +317,7 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[ class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Context[ServerSessionT, None]) -> int: + def something(a: int, ctx: Context[ServerSessionT, None]) -> int: # pragma: no cover return a manager = ToolManager() @@ -334,20 +334,22 @@ def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: + def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: # pragma: no cover return str(x) manager = ToolManager() tool = manager.add_tool(tool_with_context) assert tool.context_kwarg == "ctx" - def tool_without_context(x: int) -> str: + def tool_without_context(x: int) -> str: # pragma: no cover return str(x) tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None - def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str: + def tool_with_parametrized_context( + x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT] + ) -> str: # pragma: no cover return str(x) tool = manager.add_tool(tool_with_parametrized_context) @@ -373,7 +375,7 @@ def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: + async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: # pragma: no cover assert isinstance(ctx, Context) return str(x) @@ -418,7 +420,7 @@ class TestToolAnnotations: def test_tool_annotations(self): """Test that tool annotations are correctly added to tools.""" - def read_data(path: str) -> str: + def read_data(path: str) -> str: # pragma: no cover """Read data from a file.""" return f"Data from {path}" @@ -443,7 +445,7 @@ async def test_tool_annotations_in_fastmcp(self): app = FastMCP() @app.tool(annotations=ToolAnnotations(title="Echo Tool", readOnlyHint=True)) - def echo(message: str) -> str: + def echo(message: str) -> str: # pragma: no cover """Echo a message back.""" return message @@ -465,7 +467,7 @@ class UserOutput(BaseModel): name: str age: int - def get_user(user_id: int) -> UserOutput: + def get_user(user_id: int) -> UserOutput: # pragma: no cover """Get user by ID.""" return UserOutput(name="John", age=30) @@ -479,7 +481,7 @@ def get_user(user_id: int) -> UserOutput: async def test_tool_with_primitive_output(self): """Test tool with primitive return type.""" - def double_number(n: int) -> int: + def double_number(n: int) -> int: # pragma: no cover """Double a number.""" return 10 @@ -500,7 +502,7 @@ class UserDict(TypedDict): expected_output = {"name": "Alice", "age": 25} - def get_user_dict(user_id: int) -> UserDict: + def get_user_dict(user_id: int) -> UserDict: # pragma: no cover """Get user as dict.""" return UserDict(name="Alice", age=25) @@ -520,7 +522,7 @@ class Person: expected_output = {"name": "Bob", "age": 40} - def get_person() -> Person: + def get_person() -> Person: # pragma: no cover """Get a person.""" return Person("Bob", 40) @@ -537,7 +539,7 @@ async def test_tool_with_list_output(self): expected_list = [1, 2, 3, 4, 5] expected_output = {"result": expected_list} - def get_numbers() -> list[int]: + def get_numbers() -> list[int]: # pragma: no cover """Get a list of numbers.""" return expected_list @@ -569,7 +571,7 @@ class UserOutput(BaseModel): name: str age: int - def get_user() -> UserOutput: + def get_user() -> UserOutput: # pragma: no cover return UserOutput(name="Test", age=25) manager = ToolManager() @@ -588,7 +590,7 @@ def get_user() -> UserOutput: async def test_tool_with_dict_str_any_output(self): """Test tool with dict[str, Any] return type.""" - def get_config() -> dict[str, Any]: + def get_config() -> dict[str, Any]: # pragma: no cover """Get configuration""" return {"debug": True, "port": 8080, "features": ["auth", "logging"]} @@ -613,7 +615,7 @@ def get_config() -> dict[str, Any]: async def test_tool_with_dict_str_typed_output(self): """Test tool with dict[str, T] return type for specific T.""" - def get_scores() -> dict[str, int]: + def get_scores() -> dict[str, int]: # pragma: no cover """Get player scores""" return {"alice": 100, "bob": 85, "charlie": 92} @@ -635,13 +637,185 @@ def get_scores() -> dict[str, int]: assert result == expected +class TestToolMetadata: + """Test tool metadata functionality.""" + + def test_add_tool_with_metadata(self): + """Test adding a tool with metadata via ToolManager.""" + + def process_data(input_data: str) -> str: # pragma: no cover + """Process some data.""" + return f"Processed: {input_data}" + + metadata = {"ui": {"type": "form", "fields": ["input"]}, "version": "1.0"} + + manager = ToolManager() + tool = manager.add_tool(process_data, meta=metadata) + + assert tool.meta is not None + assert tool.meta == metadata + assert tool.meta["ui"]["type"] == "form" + assert tool.meta["version"] == "1.0" + + def test_add_tool_without_metadata(self): + """Test that tools without metadata have None as meta value.""" + + def simple_tool(x: int) -> int: # pragma: no cover + """Simple tool.""" + return x * 2 + + manager = ToolManager() + tool = manager.add_tool(simple_tool) + + assert tool.meta is None + + @pytest.mark.anyio + async def test_metadata_in_fastmcp_decorator(self): + """Test that metadata is correctly added via FastMCP.tool decorator.""" + + app = FastMCP() + + metadata = {"client": {"ui_component": "file_picker"}, "priority": "high"} + + @app.tool(meta=metadata) + def upload_file(filename: str) -> str: # pragma: no cover + """Upload a file.""" + return f"Uploaded: {filename}" + + # Get the tool from the tool manager + tool = app._tool_manager.get_tool("upload_file") + assert tool is not None + assert tool.meta is not None + assert tool.meta == metadata + assert tool.meta["client"]["ui_component"] == "file_picker" + assert tool.meta["priority"] == "high" + + @pytest.mark.anyio + async def test_metadata_in_list_tools(self): + """Test that metadata is included in MCPTool when listing tools.""" + + app = FastMCP() + + metadata = { + "ui": {"input_type": "textarea", "rows": 5}, + "tags": ["text", "processing"], + } + + @app.tool(meta=metadata) + def analyze_text(text: str) -> dict[str, Any]: # pragma: no cover + """Analyze text content.""" + return {"length": len(text), "words": len(text.split())} + + tools = await app.list_tools() + assert len(tools) == 1 + assert tools[0].meta is not None + assert tools[0].meta == metadata + + @pytest.mark.anyio + async def test_multiple_tools_with_different_metadata(self): + """Test multiple tools with different metadata values.""" + + app = FastMCP() + + metadata1 = {"ui": "form", "version": 1} + metadata2 = {"ui": "picker", "experimental": True} + + @app.tool(meta=metadata1) + def tool1(x: int) -> int: # pragma: no cover + """First tool.""" + return x + + @app.tool(meta=metadata2) + def tool2(y: str) -> str: # pragma: no cover + """Second tool.""" + return y + + @app.tool() + def tool3(z: bool) -> bool: # pragma: no cover + """Third tool without metadata.""" + return z + + tools = await app.list_tools() + assert len(tools) == 3 + + # Find tools by name and check metadata + tools_by_name = {t.name: t for t in tools} + + assert tools_by_name["tool1"].meta == metadata1 + assert tools_by_name["tool2"].meta == metadata2 + assert tools_by_name["tool3"].meta is None + + def test_metadata_with_complex_structure(self): + """Test metadata with complex nested structures.""" + + def complex_tool(data: str) -> str: # pragma: no cover + """Tool with complex metadata.""" + return data + + metadata = { + "ui": { + "components": [ + {"type": "input", "name": "field1", "validation": {"required": True, "minLength": 5}}, + {"type": "select", "name": "field2", "options": ["a", "b", "c"]}, + ], + "layout": {"columns": 2, "responsive": True}, + }, + "permissions": ["read", "write"], + "tags": ["data-processing", "user-input"], + "version": 2, + } + + manager = ToolManager() + tool = manager.add_tool(complex_tool, meta=metadata) + + assert tool.meta is not None + assert tool.meta["ui"]["components"][0]["validation"]["minLength"] == 5 + assert tool.meta["ui"]["layout"]["columns"] == 2 + assert "read" in tool.meta["permissions"] + assert "data-processing" in tool.meta["tags"] + + def test_metadata_empty_dict(self): + """Test that empty dict metadata is preserved.""" + + def tool_with_empty_meta(x: int) -> int: # pragma: no cover + """Tool with empty metadata.""" + return x + + manager = ToolManager() + tool = manager.add_tool(tool_with_empty_meta, meta={}) + + assert tool.meta is not None + assert tool.meta == {} + + @pytest.mark.anyio + async def test_metadata_with_annotations(self): + """Test that metadata and annotations can coexist.""" + + app = FastMCP() + + metadata = {"custom": "value"} + annotations = ToolAnnotations(title="Combined Tool", readOnlyHint=True) + + @app.tool(meta=metadata, annotations=annotations) + def combined_tool(data: str) -> str: # pragma: no cover + """Tool with both metadata and annotations.""" + return data + + tools = await app.list_tools() + assert len(tools) == 1 + assert tools[0].meta == metadata + assert tools[0].annotations is not None + assert tools[0].annotations.title == "Combined Tool" + assert tools[0].annotations.readOnlyHint is True + + class TestRemoveTools: """Test tool removal functionality in the tool manager.""" def test_remove_existing_tool(self): """Test removing an existing tool.""" - def add(a: int, b: int) -> int: + def add(a: int, b: int) -> int: # pragma: no cover """Add two numbers.""" return a + b @@ -669,15 +843,15 @@ def test_remove_nonexistent_tool(self): def test_remove_tool_from_multiple_tools(self): """Test removing one tool when multiple tools exist.""" - def add(a: int, b: int) -> int: + def add(a: int, b: int) -> int: # pragma: no cover """Add two numbers.""" return a + b - def multiply(a: int, b: int) -> int: + def multiply(a: int, b: int) -> int: # pragma: no cover """Multiply two numbers.""" return a * b - def divide(a: int, b: int) -> float: + def divide(a: int, b: int) -> float: # pragma: no cover """Divide two numbers.""" return a / b @@ -705,7 +879,7 @@ def divide(a: int, b: int) -> float: async def test_call_removed_tool_raises_error(self): """Test that calling a removed tool raises ToolError.""" - def greet(name: str) -> str: + def greet(name: str) -> str: # pragma: no cover """Greet someone.""" return f"Hello, {name}!" @@ -726,7 +900,7 @@ def greet(name: str) -> str: def test_remove_tool_case_sensitive(self): """Test that tool removal is case-sensitive.""" - def test_func() -> str: + def test_func() -> str: # pragma: no cover """Test function.""" return "test" diff --git a/tests/server/fastmcp/test_url_elicitation.py b/tests/server/fastmcp/test_url_elicitation.py new file mode 100644 index 0000000000..a4d3b2e643 --- /dev/null +++ b/tests/server/fastmcp/test_url_elicitation.py @@ -0,0 +1,394 @@ +"""Test URL mode elicitation feature (SEP 1036).""" + +import anyio +import pytest + +from mcp import types +from mcp.client.session import ClientSession +from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ElicitRequestParams, ElicitResult, TextContent + + +@pytest.mark.anyio +async def test_url_elicitation_accept(): + """Test URL mode elicitation with user acceptance.""" + mcp = FastMCP(name="URLElicitationServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def request_api_key(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Please provide your API key to continue.", + url="https://example.com/api_key_setup", + elicitation_id="test-elicitation-001", + ) + # Test only checks accept path + return f"User {result.action}" + + # Create elicitation callback that accepts URL mode + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + assert params.url == "https://example.com/api_key_setup" + assert params.elicitationId == "test-elicitation-001" + assert params.message == "Please provide your API key to continue." + return ElicitResult(action="http://github.com/accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("request_api_key", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User accept" + + +@pytest.mark.anyio +async def test_url_elicitation_decline(): + """Test URL mode elicitation with user declining.""" + mcp = FastMCP(name="URLElicitationDeclineServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def oauth_flow(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Authorize access to your files.", + url="https://example.com/oauth/authorize", + elicitation_id="oauth-001", + ) + # Test only checks decline path + return f"User {result.action} authorization" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + return ElicitResult(action="http://github.com/decline") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("oauth_flow", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User decline authorization" + + +@pytest.mark.anyio +async def test_url_elicitation_cancel(): + """Test URL mode elicitation with user cancelling.""" + mcp = FastMCP(name="URLElicitationCancelServer") + + @mcp.tool(description="A tool that uses URL elicitation") + async def payment_flow(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Complete payment to proceed.", + url="https://example.com/payment", + elicitation_id="payment-001", + ) + # Test only checks cancel path + return f"User {result.action} payment" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + return ElicitResult(action="http://github.com/cancel") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("payment_flow", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "User cancel payment" + + +@pytest.mark.anyio +async def test_url_elicitation_helper_function(): + """Test the elicit_url helper function.""" + from mcp.server.elicitation import elicit_url + + mcp = FastMCP(name="URLElicitationHelperServer") + + @mcp.tool(description="Tool using elicit_url helper") + async def setup_credentials(ctx: Context[ServerSession, None]) -> str: + result = await elicit_url( + session=ctx.session, + message="Set up your credentials", + url="https://example.com/setup", + elicitation_id="setup-001", + ) + # Test only checks accept path - return the type name + return type(result).__name__ + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="http://github.com/accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("setup_credentials", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "AcceptedUrlElicitation" + + +@pytest.mark.anyio +async def test_url_no_content_in_response(): + """Test that URL mode elicitation responses don't include content field.""" + mcp = FastMCP(name="URLContentCheckServer") + + @mcp.tool(description="Check URL response format") + async def check_url_response(ctx: Context[ServerSession, None]) -> str: + result = await ctx.session.elicit_url( + message="Test message", + url="https://example.com/test", + elicitation_id="test-001", + ) + + # URL mode responses should not have content + assert result.content is None + return f"Action: {result.action}, Content: {result.content}" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify that this is URL mode + assert params.mode == "url" + assert isinstance(params, types.ElicitRequestURLParams) + # URL params have url and elicitationId, not requestedSchema + assert params.url == "https://example.com/test" + assert params.elicitationId == "test-001" + # Return without content - this is correct for URL mode + return ElicitResult(action="http://github.com/accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("check_url_response", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert "Content: None" in result.content[0].text + + +@pytest.mark.anyio +async def test_form_mode_still_works(): + """Ensure form mode elicitation still works after SEP 1036.""" + from pydantic import BaseModel, Field + + mcp = FastMCP(name="FormModeBackwardCompatServer") + + class NameSchema(BaseModel): + name: str = Field(description="Your name") + + @mcp.tool(description="Test form mode") + async def ask_name(ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message="What is your name?", schema=NameSchema) + # Test only checks accept path with data + assert result.action == "accept" + assert result.data is not None + return f"Hello, {result.data.name}!" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify form mode parameters + assert params.mode == "form" + assert isinstance(params, types.ElicitRequestFormParams) + # Form params have requestedSchema, not url/elicitationId + assert params.requestedSchema is not None + return ElicitResult(action="http://github.com/accept", content={"name": "Alice"}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("ask_name", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Hello, Alice!" + + +@pytest.mark.anyio +async def test_elicit_complete_notification(): + """Test that elicitation completion notifications can be sent and received.""" + mcp = FastMCP(name="ElicitCompleteServer") + + # Track if the notification was sent + notification_sent = False + + @mcp.tool(description="Tool that sends completion notification") + async def trigger_elicitation(ctx: Context[ServerSession, None]) -> str: + nonlocal notification_sent + + # Simulate an async operation (e.g., user completing auth in browser) + elicitation_id = "complete-test-001" + + # Send completion notification + await ctx.session.send_elicit_complete(elicitation_id) + notification_sent = True + + return "Elicitation completed" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="http://github.com/accept") # pragma: no cover + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("trigger_elicitation", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Elicitation completed" + + # Give time for notification to be processed + await anyio.sleep(0.1) + + # Verify the notification was sent + assert notification_sent + + +@pytest.mark.anyio +async def test_url_elicitation_required_error_code(): + """Test that the URL_ELICITATION_REQUIRED error code is correct.""" + # Verify the error code matches the specification (SEP 1036) + assert types.URL_ELICITATION_REQUIRED == -32042, ( + "URL_ELICITATION_REQUIRED error code must be -32042 per SEP 1036 specification" + ) + + +@pytest.mark.anyio +async def test_elicit_url_typed_results(): + """Test that elicit_url returns properly typed result objects.""" + from mcp.server.elicitation import elicit_url + + mcp = FastMCP(name="TypedResultsServer") + + @mcp.tool(description="Test declined result") + async def test_decline(ctx: Context[ServerSession, None]) -> str: + result = await elicit_url( + session=ctx.session, + message="Test decline", + url="https://example.com/decline", + elicitation_id="decline-001", + ) + + if isinstance(result, DeclinedElicitation): + return "Declined" + return "Not declined" # pragma: no cover + + @mcp.tool(description="Test cancelled result") + async def test_cancel(ctx: Context[ServerSession, None]) -> str: + result = await elicit_url( + session=ctx.session, + message="Test cancel", + url="https://example.com/cancel", + elicitation_id="cancel-001", + ) + + if isinstance(result, CancelledElicitation): + return "Cancelled" + return "Not cancelled" # pragma: no cover + + # Test declined result + async def decline_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="http://github.com/decline") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=decline_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("test_decline", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Declined" + + # Test cancelled result + async def cancel_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="http://github.com/cancel") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=cancel_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("test_cancel", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Cancelled" + + +@pytest.mark.anyio +async def test_deprecated_elicit_method(): + """Test the deprecated elicit() method for backward compatibility.""" + from pydantic import BaseModel, Field + + mcp = FastMCP(name="DeprecatedElicitServer") + + class EmailSchema(BaseModel): + email: str = Field(description="Email address") + + @mcp.tool(description="Test deprecated elicit method") + async def use_deprecated_elicit(ctx: Context[ServerSession, None]) -> str: + # Use the deprecated elicit() method which should call elicit_form() + result = await ctx.session.elicit( + message="Enter your email", + requestedSchema=EmailSchema.model_json_schema(), + ) + + if result.action == "accept" and result.content: + return f"Email: {result.content.get('email', 'none')}" + return "No email provided" # pragma: no cover + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Verify this is form mode + assert params.mode == "form" + assert params.requestedSchema is not None + return ElicitResult(action="http://github.com/accept", content={"email": "test@example.com"}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("use_deprecated_elicit", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Email: test@example.com" + + +@pytest.mark.anyio +async def test_ctx_elicit_url_convenience_method(): + """Test the ctx.elicit_url() convenience method (vs ctx.session.elicit_url()).""" + mcp = FastMCP(name="CtxElicitUrlServer") + + @mcp.tool(description="A tool that uses ctx.elicit_url() directly") + async def direct_elicit_url(/service/ctx: Context[ServerSession, None]) -> str: + # Use ctx.elicit_url() directly instead of ctx.session.elicit_url() + result = await ctx.elicit_url( + message="Test the convenience method", + url="https://example.com/test", + elicitation_id="ctx-test-001", + ) + return f"Result: {result.action}" + + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + assert params.mode == "url" + assert params.elicitationId == "ctx-test-001" + return ElicitResult(action="http://github.com/accept") + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + result = await client_session.call_tool("direct_elicit_url", {}) + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Result: accept" diff --git a/tests/server/fastmcp/test_url_elicitation_error_throw.py b/tests/server/fastmcp/test_url_elicitation_error_throw.py new file mode 100644 index 0000000000..2d7eda4ab4 --- /dev/null +++ b/tests/server/fastmcp/test_url_elicitation_error_throw.py @@ -0,0 +1,113 @@ +"""Test that UrlElicitationRequiredError is properly propagated as MCP error.""" + +import pytest + +from mcp import types +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession +from mcp.shared.exceptions import McpError, UrlElicitationRequiredError +from mcp.shared.memory import create_connected_server_and_client_session + + +@pytest.mark.anyio +async def test_url_elicitation_error_thrown_from_tool(): + """Test that UrlElicitationRequiredError raised from a tool is received as McpError by client.""" + mcp = FastMCP(name="UrlElicitationErrorServer") + + @mcp.tool(description="A tool that raises UrlElicitationRequiredError") + async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str: + # This tool cannot proceed without authorization + raise UrlElicitationRequiredError( + [ + types.ElicitRequestURLParams( + mode="url", + message=f"Authorization required to connect to {service_name}", + url=f"https://{service_name}.example.com/oauth/authorize", + elicitationId=f"{service_name}-auth-001", + ) + ] + ) + + async with create_connected_server_and_client_session(mcp._mcp_server) as client_session: + await client_session.initialize() + + # Call the tool - it should raise McpError with URL_ELICITATION_REQUIRED code + with pytest.raises(McpError) as exc_info: + await client_session.call_tool("connect_service", {"service_name": "github"}) + + # Verify the error details + error = exc_info.value.error + assert error.code == types.URL_ELICITATION_REQUIRED + assert error.message == "URL elicitation required" + + # Verify the error data contains elicitations + assert error.data is not None + assert "elicitations" in error.data + elicitations = error.data["elicitations"] + assert len(elicitations) == 1 + assert elicitations[0]["mode"] == "url" + assert elicitations[0]["url"] == "https://github.example.com/oauth/authorize" + assert elicitations[0]["elicitationId"] == "github-auth-001" + + +@pytest.mark.anyio +async def test_url_elicitation_error_from_error(): + """Test that client can reconstruct UrlElicitationRequiredError from McpError.""" + mcp = FastMCP(name="UrlElicitationErrorServer") + + @mcp.tool(description="A tool that raises UrlElicitationRequiredError with multiple elicitations") + async def multi_auth(ctx: Context[ServerSession, None]) -> str: + raise UrlElicitationRequiredError( + [ + types.ElicitRequestURLParams( + mode="url", + message="GitHub authorization required", + url="https://github.example.com/oauth", + elicitationId="github-auth", + ), + types.ElicitRequestURLParams( + mode="url", + message="Google Drive authorization required", + url="https://drive.google.com/oauth", + elicitationId="gdrive-auth", + ), + ] + ) + + async with create_connected_server_and_client_session(mcp._mcp_server) as client_session: + await client_session.initialize() + + # Call the tool and catch the error + with pytest.raises(McpError) as exc_info: + await client_session.call_tool("multi_auth", {}) + + # Reconstruct the typed error + mcp_error = exc_info.value + assert mcp_error.error.code == types.URL_ELICITATION_REQUIRED + + url_error = UrlElicitationRequiredError.from_error(mcp_error.error) + + # Verify the reconstructed error has both elicitations + assert len(url_error.elicitations) == 2 + assert url_error.elicitations[0].elicitationId == "github-auth" + assert url_error.elicitations[1].elicitationId == "gdrive-auth" + + +@pytest.mark.anyio +async def test_normal_exceptions_still_return_error_result(): + """Test that normal exceptions still return CallToolResult with isError=True.""" + mcp = FastMCP(name="NormalErrorServer") + + @mcp.tool(description="A tool that raises a normal exception") + async def failing_tool(ctx: Context[ServerSession, None]) -> str: + raise ValueError("Something went wrong") + + async with create_connected_server_and_client_session(mcp._mcp_server) as client_session: + await client_session.initialize() + + # Normal exceptions should be returned as error results, not McpError + result = await client_session.call_tool("failing_tool", {}) + assert result.isError is True + assert len(result.content) == 1 + assert isinstance(result.content[0], types.TextContent) + assert "Something went wrong" in result.content[0].text diff --git a/tests/server/lowlevel/test_func_inspection.py b/tests/server/lowlevel/test_func_inspection.py index 556fede4aa..9cb2b561ac 100644 --- a/tests/server/lowlevel/test_func_inspection.py +++ b/tests/server/lowlevel/test_func_inspection.py @@ -121,7 +121,7 @@ async def test_untyped_request_param_is_deprecated() -> None: """Test: def foo(req) - should call without request.""" called = False - async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] + async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[reportMissingParameterType] # pragma: no cover nonlocal called called = True return ["test"] @@ -139,7 +139,7 @@ async def handler(req): # type: ignore[no-untyped-def] # pyright: ignore[repor async def test_any_typed_request_param_is_deprecated() -> None: """Test: def foo(req: Any) - should call without request.""" - async def handler(req: Any) -> list[str]: + async def handler(req: Any) -> list[str]: # pragma: no cover return ["test"] wrapper = create_call_wrapper(handler, ListPromptsRequest) @@ -155,7 +155,7 @@ async def handler(req: Any) -> list[str]: async def test_generic_typed_request_param_is_deprecated() -> None: """Test: def foo(req: Generic[T]) - should call without request.""" - async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] + async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGeneralTypeIssues] # pragma: no cover return ["test"] wrapper = create_call_wrapper(handler, ListPromptsRequest) @@ -171,7 +171,7 @@ async def handler(req: Generic[T]) -> list[str]: # pyright: ignore[reportGenera async def test_wrong_typed_request_param_is_deprecated() -> None: """Test: def foo(req: str) - should call without request.""" - async def handler(req: str) -> list[str]: + async def handler(req: str) -> list[str]: # pragma: no cover return ["test"] wrapper = create_call_wrapper(handler, ListPromptsRequest) @@ -188,7 +188,7 @@ async def test_required_param_before_typed_request_attempts_to_pass() -> None: """Test: def foo(thing: int, req: ListPromptsRequest) - attempts to pass request (will fail at runtime).""" received_request = None - async def handler(thing: int, req: ListPromptsRequest) -> list[str]: + async def handler(thing: int, req: ListPromptsRequest) -> list[str]: # pragma: no cover nonlocal received_request received_request = req return ["test"] @@ -280,7 +280,7 @@ async def handler2(req: ListToolsRequest) -> list[str]: async def test_mixed_params_with_typed_request() -> None: """Test: def foo(a: str, req: ListPromptsRequest, b: int = 5) - attempts to pass request.""" - async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: + async def handler(a: str, req: ListPromptsRequest, b: int = 5) -> list[str]: # pragma: no cover return ["test"] wrapper = create_call_wrapper(handler, ListPromptsRequest) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 516642c4b0..47c49bb62b 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -52,7 +52,7 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ ev_first_call.set() await anyio.sleep(5) # First call is slow return [types.TextContent(type="text", text=f"Call number: {call_count}")] - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {name}") # pragma: no cover async with create_connected_server_and_client_session(server) as client: # First request (will be cancelled) @@ -66,7 +66,7 @@ async def first_request(): ), CallToolResult, ) - pytest.fail("First request should have been cancelled") + pytest.fail("First request should have been cancelled") # pragma: no cover except McpError: pass # Expected diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index f0864667dc..eb9604791a 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -104,10 +104,10 @@ async def handle_completion( db = context.arguments.get("database") if db == "users_db": return Completion(values=["users", "sessions", "permissions"], total=3, hasMore=False) - elif db == "products_db": + elif db == "products_db": # pragma: no cover return Completion(values=["products", "categories", "inventory"], total=3, hasMore=False) - return Completion(values=[], total=0, hasMore=False) + return Completion(values=[], total=0, hasMore=False) # pragma: no cover async with create_connected_server_and_client_session(server) as client: # First, complete database @@ -155,10 +155,10 @@ async def handle_completion( raise ValueError("Please select a database first to see available tables") # Normal completion if context is provided db = context.arguments.get("database") - if db == "test_db": + if db == "test_db": # pragma: no cover return Completion(values=["users", "orders", "products"], total=3, hasMore=False) - return Completion(values=[], total=0, hasMore=False) + return Completion(values=[], total=0, hasMore=False) # pragma: no cover async with create_connected_server_and_client_session(server) as client: # Try to complete table without database context - should raise error diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py new file mode 100644 index 0000000000..5d4c3347f6 --- /dev/null +++ b/tests/server/test_lowlevel_exception_handling.py @@ -0,0 +1,74 @@ +from unittest.mock import AsyncMock, Mock + +import pytest + +import mcp.types as types +from mcp.server.lowlevel.server import Server +from mcp.server.session import ServerSession +from mcp.shared.session import RequestResponder + + +@pytest.mark.anyio +async def test_exception_handling_with_raise_exceptions_true(): + """Test that exceptions are re-raised when raise_exceptions=True""" + server = Server("test-server") + session = Mock(spec=ServerSession) + session.send_log_message = AsyncMock() + + test_exception = RuntimeError("Test error") + + with pytest.raises(RuntimeError, match="Test error"): + await server._handle_message(test_exception, session, {}, raise_exceptions=True) + + session.send_log_message.assert_called_once() + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "exception_class,message", + [ + (ValueError, "Test validation error"), + (RuntimeError, "Test runtime error"), + (KeyError, "Test key error"), + (Exception, "Basic error"), + ], +) +async def test_exception_handling_with_raise_exceptions_false(exception_class: type[Exception], message: str): + """Test that exceptions are logged when raise_exceptions=False""" + server = Server("test-server") + session = Mock(spec=ServerSession) + session.send_log_message = AsyncMock() + + test_exception = exception_class(message) + + await server._handle_message(test_exception, session, {}, raise_exceptions=False) + + # Should send log message + session.send_log_message.assert_called_once() + call_args = session.send_log_message.call_args + + assert call_args.kwargs["level"] == "error" + assert call_args.kwargs["data"] == "Internal Server Error" + assert call_args.kwargs["logger"] == "mcp.server.exception_handler" + + +@pytest.mark.anyio +async def test_normal_message_handling_not_affected(): + """Test that normal messages still work correctly""" + server = Server("test-server") + session = Mock(spec=ServerSession) + + # Create a mock RequestResponder + responder = Mock(spec=RequestResponder) + responder.request = types.ClientRequest(root=types.PingRequest(method="ping")) + responder.__enter__ = Mock(return_value=responder) + responder.__exit__ = Mock(return_value=None) + + # Mock the _handle_request method to avoid complex setup + server._handle_request = AsyncMock() + + # Should handle normally without any exception handling + await server._handle_message(responder, session, {}, raise_exceptions=False) + + # Verify _handle_request was called + server._handle_request.assert_called_once() diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py index 8de5494a81..47cb57232d 100644 --- a/tests/server/test_lowlevel_input_validation.py +++ b/tests/server/test_lowlevel_input_validation.py @@ -50,7 +50,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: async def message_handler( message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message # Server task @@ -122,7 +122,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextCo if name == "add": result = arguments["a"] + arguments["b"] return [TextContent(type="text", text=f"Result: {result}")] - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -143,7 +143,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: async def test_invalid_tool_call_missing_required(): """Test that missing required arguments fail validation.""" - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover # This should not be reached due to validation raise RuntimeError("Should not reach here") @@ -166,7 +166,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: async def test_invalid_tool_call_wrong_type(): """Test that wrong argument types fail validation.""" - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover # This should not be reached due to validation raise RuntimeError("Should not reach here") @@ -207,7 +207,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextCo if name == "multiply": result = arguments["x"] * arguments["y"] return [TextContent(type="text", text=f"Result: {result}")] - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -244,7 +244,7 @@ async def test_enum_constraint_validation(): ) ] - async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: # pragma: no cover # This should not be reached due to validation failure raise RuntimeError("Should not reach here") @@ -286,7 +286,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextCo if name == "unknown_tool": # Even with invalid arguments, this should execute since validation is skipped return [TextContent(type="text", text="Unknown tool executed without validation")] - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: diff --git a/tests/server/test_lowlevel_output_validation.py b/tests/server/test_lowlevel_output_validation.py index 7bcdf59d3d..f735445212 100644 --- a/tests/server/test_lowlevel_output_validation.py +++ b/tests/server/test_lowlevel_output_validation.py @@ -48,7 +48,7 @@ async def call_tool(name: str, arguments: dict[str, Any]): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) # Message handler for client - async def message_handler( + async def message_handler( # pragma: no cover message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, ) -> None: if isinstance(message, Exception): @@ -119,7 +119,7 @@ async def test_content_only_without_output_schema(): async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: if name == "echo": return [TextContent(type="text", text=f"Echo: {arguments['message']}")] - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -155,7 +155,7 @@ async def test_dict_only_without_output_schema(): async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, Any]: if name == "get_info": return {"status": "ok", "data": {"value": 42}} - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -194,7 +194,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[ content = [TextContent(type="text", text="Processing complete")] data = {"result": "success", "count": 10} return (content, data) - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -282,7 +282,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, A x = arguments["x"] y = arguments["y"] return {"sum": x + y, "product": x * y} - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -326,7 +326,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, A if name == "user_info": # Missing required 'age' field return {"name": "Alice"} - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -374,7 +374,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> tuple[list[ content = [TextContent(type="text", text=f"Analysis of: {arguments['text']}")] data = {"sentiment": "positive", "confidence": 0.95} return (content, data) - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: @@ -391,6 +391,47 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: assert result.structuredContent == {"sentiment": "positive", "confidence": 0.95} +@pytest.mark.anyio +async def test_tool_call_result(): + """Test returning ToolCallResult when no outputSchema is defined.""" + tools = [ + Tool( + name="get_info", + description="Get structured information", + inputSchema={ + "type": "object", + "properties": {}, + }, + # No outputSchema for direct return of tool call result + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> CallToolResult: + if name == "get_info": + return CallToolResult( + content=[TextContent(type="text", text="Results calculated")], + structuredContent={"status": "ok", "data": {"value": 42}}, + _meta={"some": "metadata"}, + ) + else: # pragma: no cover + raise ValueError(f"Unknown tool: {name}") + + async def test_callback(client_session: ClientSession) -> CallToolResult: + return await client_session.call_tool("get_info", {}) + + result = await run_tool_test(tools, call_tool_handler, test_callback) + + # Verify results + assert result is not None + assert not result.isError + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Results calculated" + assert isinstance(result.content[0], TextContent) + assert result.structuredContent == {"status": "ok", "data": {"value": 42}} + assert result.meta == {"some": "metadata"} + + @pytest.mark.anyio async def test_output_schema_type_validation(): """Test outputSchema validates types correctly.""" @@ -418,7 +459,7 @@ async def call_tool_handler(name: str, arguments: dict[str, Any]) -> dict[str, A if name == "stats": # Wrong type for 'count' - should be integer return {"count": "five", "average": 2.5, "items": ["a", "b"]} - else: + else: # pragma: no cover raise ValueError(f"Unknown tool: {name}") async def test_callback(client_session: ClientSession) -> CallToolResult: diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 33685f8f9e..f812c48777 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -20,7 +20,7 @@ async def test_lowlevel_server_tool_annotations(): # Create a tool with annotations @server.list_tools() - async def list_tools(): + async def list_tools(): # pragma: no cover return [ Tool( name="echo", @@ -47,7 +47,7 @@ async def list_tools(): async def message_handler( message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message # Server task diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index d97477e102..c31b90c557 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -18,7 +18,7 @@ def temp_file(): yield path try: path.unlink() - except FileNotFoundError: + except FileNotFoundError: # pragma: no cover pass diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 664867511c..34f9c6e28e 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -9,6 +9,7 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( @@ -34,7 +35,7 @@ async def test_server_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) # Create a message handler to catch exceptions - async def message_handler( + async def message_handler( # pragma: no cover message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): @@ -54,11 +55,13 @@ async def run_server(): capabilities=ServerCapabilities(), ), ) as server_session: - async for message in server_session.incoming_messages: - if isinstance(message, Exception): + async for message in server_session.incoming_messages: # pragma: no branch + if isinstance(message, Exception): # pragma: no cover raise message - if isinstance(message, ClientNotification) and isinstance(message.root, InitializedNotification): + if isinstance(message, ClientNotification) and isinstance( + message.root, InitializedNotification + ): # pragma: no branch received_initialized = True return @@ -74,7 +77,7 @@ async def run_server(): tg.start_soon(run_server) await client_session.initialize() - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover pass assert received_initialized @@ -94,7 +97,7 @@ async def test_server_capabilities(): # Add a prompts handler @server.list_prompts() - async def list_prompts() -> list[Prompt]: + async def list_prompts() -> list[Prompt]: # pragma: no cover return [] caps = server.get_capabilities(notification_options, experimental_capabilities) @@ -104,7 +107,7 @@ async def list_prompts() -> list[Prompt]: # Add a resources handler @server.list_resources() - async def list_resources() -> list[Resource]: + async def list_resources() -> list[Resource]: # pragma: no cover return [] caps = server.get_capabilities(notification_options, experimental_capabilities) @@ -114,7 +117,7 @@ async def list_resources() -> list[Resource]: # Add a complete handler @server.completion() - async def complete( + async def complete( # pragma: no cover ref: PromptReference | ResourceTemplateReference, argument: CompletionArgument, context: CompletionContext | None, @@ -150,11 +153,13 @@ async def run_server(): capabilities=ServerCapabilities(), ), ) as server_session: - async for message in server_session.incoming_messages: - if isinstance(message, Exception): + async for message in server_session.incoming_messages: # pragma: no branch + if isinstance(message, Exception): # pragma: no cover raise message - if isinstance(message, types.ClientNotification) and isinstance(message.root, InitializedNotification): + if isinstance(message, types.ClientNotification) and isinstance( + message.root, InitializedNotification + ): # pragma: no branch received_initialized = True return @@ -234,12 +239,14 @@ async def run_server(): capabilities=ServerCapabilities(), ), ) as server_session: - async for message in server_session.incoming_messages: - if isinstance(message, Exception): + async for message in server_session.incoming_messages: # pragma: no branch + if isinstance(message, Exception): # pragma: no cover raise message # We should receive a ping request before initialization - if isinstance(message, RequestResponder) and isinstance(message.request.root, types.PingRequest): + if isinstance(message, RequestResponder) and isinstance( + message.request.root, types.PingRequest + ): # pragma: no branch # Respond to the ping with message: await message.respond(types.ServerResult(types.EmptyResult())) @@ -282,6 +289,184 @@ async def mock_client(): assert ping_response_id == 42 +@pytest.mark.anyio +async def test_create_message_tool_result_validation(): + """Test tool_use/tool_result validation in create_message.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test", + server_version="0.1.0", + capabilities=ServerCapabilities(), + ), + ) as session: + # Set up client params with sampling.tools capability for the test + session._client_params = types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( + sampling=types.SamplingCapability(tools=types.SamplingToolsCapability()) + ), + clientInfo=types.Implementation(name="test", version="1.0"), + ) + + tool = types.Tool(name="test_tool", inputSchema={"type": "object"}) + text = types.TextContent(type="text", text="hello") + tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={}) + tool_result = types.ToolResultContent(type="tool_result", toolUseId="call_1", content=[]) + + # Case 1: tool_result mixed with other content + with pytest.raises(ValueError, match="only tool_result content"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage(role="user", content=[tool_result, text]), # mixed! + ], + max_tokens=100, + tools=[tool], + ) + + # Case 2: tool_result without previous message + with pytest.raises(ValueError, match="requires a previous message"): + await session.create_message( + messages=[types.SamplingMessage(role="user", content=tool_result)], + max_tokens=100, + tools=[tool], + ) + + # Case 3: tool_result without previous tool_use + with pytest.raises(ValueError, match="do not match any tool_use"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 4: mismatched tool IDs + with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage( + role="user", + content=types.ToolResultContent(type="tool_result", toolUseId="wrong_id", content=[]), + ), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 5: text-only message with tools (no tool_results) - passes validation + # Covers has_tool_results=False branch. + # We use move_on_after because validation happens synchronously before + # send_request, which would block indefinitely waiting for a response. + # The timeout lets validation pass, then cancels the blocked send. + with anyio.move_on_after(0.01): + await session.create_message( + messages=[types.SamplingMessage(role="user", content=text)], + max_tokens=100, + tools=[tool], + ) + + # Case 6: valid matching tool_result/tool_use IDs - passes validation + # Covers tool_use_ids == tool_result_ids branch. + # (see Case 5 comment for move_on_after explanation) + with anyio.move_on_after(0.01): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="assistant", content=tool_use), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + tools=[tool], + ) + + # Case 7: validation runs even without `tools` parameter + # (tool loop continuation may omit tools while containing tool_result) + with pytest.raises(ValueError, match="do not match any tool_use"): + await session.create_message( + messages=[ + types.SamplingMessage(role="user", content=text), + types.SamplingMessage(role="user", content=tool_result), + ], + max_tokens=100, + # Note: no tools parameter + ) + + # Case 8: empty messages list - skips validation entirely + # Covers the `if messages:` branch (line 280->302) + with anyio.move_on_after(0.01): + await session.create_message( + messages=[], + max_tokens=100, + ) + + +@pytest.mark.anyio +async def test_create_message_without_tools_capability(): + """Test that create_message raises McpError when tools are provided without capability.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test", + server_version="0.1.0", + capabilities=ServerCapabilities(), + ), + ) as session: + # Set up client params WITHOUT sampling.tools capability + session._client_params = types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities(sampling=types.SamplingCapability()), + clientInfo=types.Implementation(name="test", version="1.0"), + ) + + tool = types.Tool(name="test_tool", inputSchema={"type": "object"}) + text = types.TextContent(type="text", text="hello") + + # Should raise McpError when tools are provided but client lacks capability + with pytest.raises(McpError) as exc_info: + await session.create_message( + messages=[types.SamplingMessage(role="user", content=text)], + max_tokens=100, + tools=[tool], + ) + assert "does not support sampling tools capability" in exc_info.value.error.message + + # Should also raise McpError when tool_choice is provided + with pytest.raises(McpError) as exc_info: + await session.create_message( + messages=[types.SamplingMessage(role="user", content=text)], + max_tokens=100, + tool_choice=types.ToolChoice(mode="auto"), + ) + assert "does not support sampling tools capability" in exc_info.value.error.message + + @pytest.mark.anyio async def test_other_requests_blocked_before_initialization(): """Test that non-ping requests are still blocked before initialization.""" @@ -323,7 +508,7 @@ async def mock_client(): # Wait for the error response error_message = await server_to_client_receive.receive() - if isinstance(error_message.message.root, types.JSONRPCError): + if isinstance(error_message.message.root, types.JSONRPCError): # pragma: no branch error_response_received = True error_code = error_message.message.root.error.code diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py new file mode 100644 index 0000000000..b5388167ad --- /dev/null +++ b/tests/server/test_session_race_condition.py @@ -0,0 +1,155 @@ +""" +Test for race condition fix in initialization flow. + +This test verifies that requests can be processed immediately after +responding to InitializeRequest, without waiting for InitializedNotification. + +This is critical for HTTP transport where requests can arrive in any order. +""" + +import anyio +import pytest + +import mcp.types as types +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ServerCapabilities, Tool + + +@pytest.mark.anyio +async def test_request_immediately_after_initialize_response(): + """ + Test that requests are accepted immediately after initialize response. + + This reproduces the race condition in stateful HTTP mode where: + 1. Client sends InitializeRequest + 2. Server responds with InitializeResult + 3. Client immediately sends tools/list (before server receives InitializedNotification) + 4. Without fix: Server rejects with "Received request before initialization was complete" + 5. With fix: Server accepts and processes the request + + This test simulates the HTTP transport behavior where InitializedNotification + may arrive in a separate POST request after other requests. + """ + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](10) + + tools_list_success = False + error_received = None + + async def run_server(): + nonlocal tools_list_success + + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=ServerCapabilities( + tools=types.ToolsCapability(listChanged=False), + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: # pragma: no branch + if isinstance(message, Exception): # pragma: no cover + raise message + + # Handle tools/list request + if isinstance(message, RequestResponder): + if isinstance(message.request.root, types.ListToolsRequest): # pragma: no branch + tools_list_success = True + # Respond with a tool list + with message: + await message.respond( + types.ServerResult( + types.ListToolsResult( + tools=[ + Tool( + name="example_tool", + description="An example tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + ) + ) + ) + + # Handle InitializedNotification + if isinstance(message, types.ClientNotification): + if isinstance(message.root, types.InitializedNotification): # pragma: no branch + # Done - exit gracefully + return + + async def mock_client(): + nonlocal error_received + + # Step 1: Send InitializeRequest + await client_to_server_send.send( + SessionMessage( + types.JSONRPCMessage( + types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities(), + clientInfo=types.Implementation(name="test-client", version="1.0.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Step 2: Wait for InitializeResult + init_msg = await server_to_client_receive.receive() + assert isinstance(init_msg.message.root, types.JSONRPCResponse) + + # Step 3: Immediately send tools/list BEFORE InitializedNotification + # This is the race condition scenario + await client_to_server_send.send( + SessionMessage( + types.JSONRPCMessage( + types.JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/list", + ) + ) + ) + ) + + # Step 4: Check the response + tools_msg = await server_to_client_receive.receive() + if isinstance(tools_msg.message.root, types.JSONRPCError): # pragma: no cover + error_received = tools_msg.message.root.error.message + + # Step 5: Send InitializedNotification + await client_to_server_send.send( + SessionMessage( + types.JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + ) + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + tg.start_soon(mock_client) + + # With the PR fix: tools_list_success should be True, error_received should be None + # Without the fix: error_received would contain "Received request before initialization was complete" + assert tools_list_success, f"tools/list should have succeeded. Error received: {error_received}" + assert error_received is None, f"Expected no error, but got: {error_received}" diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index bdaec6bdba..010eaf6a25 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -3,7 +3,6 @@ import logging import multiprocessing import socket -import time import httpx import pytest @@ -17,6 +16,7 @@ from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool +from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" @@ -30,11 +30,11 @@ def server_port() -> int: @pytest.fixture -def server_url(http://github.com/server_port:%20int) -> str: +def server_url(http://github.com/server_port:%20int) -> str: # pragma: no cover return f"http://127.0.0.1:{server_port}" -class SecurityTestServer(Server): +class SecurityTestServer(Server): # pragma: no cover def __init__(self): super().__init__(SERVER_NAME) @@ -42,7 +42,7 @@ async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): +def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover """Run the SSE server with specified security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) @@ -70,8 +70,8 @@ def start_server_process(port: int, security_settings: TransportSecuritySettings """Start server in a separate process.""" process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() - # Give server time to start - time.sleep(1) + # Wait for server to be ready to accept connections + wait_for_server(port) return process diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index a1d1792f88..13cdde3d61 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -29,7 +29,7 @@ async def test_stdio_server(): received_messages: list[JSONRPCMessage] = [] async with read_stream: async for message in read_stream: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message received_messages.append(message.message) if len(received_messages) == 2: diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 7a8551e5c6..6fcf08aa00 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -26,7 +26,7 @@ async def test_run_can_only_be_called_once(): # Second call should raise RuntimeError with pytest.raises(RuntimeError) as excinfo: async with manager.run(): - pass + pass # pragma: no cover assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(excinfo.value) @@ -66,10 +66,10 @@ async def test_handle_request_without_run_raises_error(): # Mock ASGI parameters scope = {"type": "http", "method": "POST", "path": "/test"} - async def receive(): + async def receive(): # pragma: no cover return {"type": "http.request", "body": b""} - async def send(message: Message): + async def send(message: Message): # pragma: no cover pass # Should raise error because run() hasn't been called @@ -114,7 +114,7 @@ async def mock_send(message: Message): "headers": [(b"content-type", b"application/json")], } - async def mock_receive(): + async def mock_receive(): # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation @@ -122,13 +122,13 @@ async def mock_receive(): # Extract session ID from response headers session_id = None - for msg in sent_messages: - if msg["type"] == "http.response.start": - for header_name, header_value in msg.get("headers", []): + for msg in sent_messages: # pragma: no branch + if msg["type"] == "http.response.start": # pragma: no branch + for header_name, header_value in msg.get("headers", []): # pragma: no branch if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): session_id = header_value.decode() break - if session_id: # Break outer loop if session_id is found + if session_id: # Break outer loop if session_id is found # pragma: no branch break assert session_id is not None, "Session ID not found in response headers" @@ -163,7 +163,7 @@ async def mock_send(message: Message): # If an exception occurs, the transport might try to send an error response # For this test, we mostly care that the session is established enough # to get an ID - if message["type"] == "http.response.start" and message["status"] >= 500: + if message["type"] == "http.response.start" and message["status"] >= 500: # pragma: no cover pass # Expected if TestException propagates that far up the transport scope = { @@ -173,20 +173,20 @@ async def mock_send(message: Message): "headers": [(b"content-type", b"application/json")], } - async def mock_receive(): + async def mock_receive(): # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation await manager.handle_request(scope, mock_receive, mock_send) session_id = None - for msg in sent_messages: - if msg["type"] == "http.response.start": - for header_name, header_value in msg.get("headers", []): + for msg in sent_messages: # pragma: no branch + if msg["type"] == "http.response.start": # pragma: no branch + for header_name, header_value in msg.get("headers", []): # pragma: no branch if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): session_id = header_value.decode() break - if session_id: # Break outer loop if session_id is found + if session_id: # Break outer loop if session_id is found # pragma: no branch break assert session_id is not None, "Session ID not found in response headers" diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index b9cd83dc1b..a637b1dce0 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -3,7 +3,6 @@ import logging import multiprocessing import socket -import time from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -18,6 +17,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool +from tests.test_helpers import wait_for_server logger = logging.getLogger(__name__) SERVER_NAME = "test_streamable_http_security_server" @@ -31,11 +31,11 @@ def server_port() -> int: @pytest.fixture -def server_url(http://github.com/server_port:%20int) -> str: +def server_url(http://github.com/server_port:%20int) -> str: # pragma: no cover return f"http://127.0.0.1:{server_port}" -class SecurityTestServer(Server): +class SecurityTestServer(Server): # pragma: no cover def __init__(self): super().__init__(SERVER_NAME) @@ -43,7 +43,7 @@ async def on_list_tools(self) -> list[Tool]: return [] -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): +def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover """Run the StreamableHTTP server with specified security settings.""" app = SecurityTestServer() @@ -77,8 +77,8 @@ def start_server_process(port: int, security_settings: TransportSecuritySettings """Start server in a separate process.""" process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) process.start() - # Give server time to start - time.sleep(1) + # Wait for server to be ready to accept connections + wait_for_server(port) return process diff --git a/tests/server/test_validation.py b/tests/server/test_validation.py new file mode 100644 index 0000000000..56044460df --- /dev/null +++ b/tests/server/test_validation.py @@ -0,0 +1,141 @@ +"""Tests for server validation functions.""" + +import pytest + +from mcp.server.validation import ( + check_sampling_tools_capability, + validate_sampling_tools, + validate_tool_use_result_messages, +) +from mcp.shared.exceptions import McpError +from mcp.types import ( + ClientCapabilities, + SamplingCapability, + SamplingMessage, + SamplingToolsCapability, + TextContent, + Tool, + ToolChoice, + ToolResultContent, + ToolUseContent, +) + + +class TestCheckSamplingToolsCapability: + """Tests for check_sampling_tools_capability function.""" + + def test_returns_false_when_caps_none(self) -> None: + """Returns False when client_caps is None.""" + assert check_sampling_tools_capability(None) is False + + def test_returns_false_when_sampling_none(self) -> None: + """Returns False when client_caps.sampling is None.""" + caps = ClientCapabilities() + assert check_sampling_tools_capability(caps) is False + + def test_returns_false_when_tools_none(self) -> None: + """Returns False when client_caps.sampling.tools is None.""" + caps = ClientCapabilities(sampling=SamplingCapability()) + assert check_sampling_tools_capability(caps) is False + + def test_returns_true_when_tools_present(self) -> None: + """Returns True when sampling.tools is present.""" + caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + assert check_sampling_tools_capability(caps) is True + + +class TestValidateSamplingTools: + """Tests for validate_sampling_tools function.""" + + def test_no_error_when_tools_none(self) -> None: + """No error when tools and tool_choice are None.""" + validate_sampling_tools(None, None, None) # Should not raise + + def test_raises_when_tools_provided_but_no_capability(self) -> None: + """Raises McpError when tools provided but client doesn't support.""" + tool = Tool(name="test", inputSchema={"type": "object"}) + with pytest.raises(McpError) as exc_info: + validate_sampling_tools(None, [tool], None) + assert "sampling tools capability" in str(exc_info.value) + + def test_raises_when_tool_choice_provided_but_no_capability(self) -> None: + """Raises McpError when tool_choice provided but client doesn't support.""" + with pytest.raises(McpError) as exc_info: + validate_sampling_tools(None, None, ToolChoice(mode="auto")) + assert "sampling tools capability" in str(exc_info.value) + + def test_no_error_when_capability_present(self) -> None: + """No error when client has sampling.tools capability.""" + caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + tool = Tool(name="test", inputSchema={"type": "object"}) + validate_sampling_tools(caps, [tool], ToolChoice(mode="auto")) # Should not raise + + +class TestValidateToolUseResultMessages: + """Tests for validate_tool_use_result_messages function.""" + + def test_no_error_for_empty_messages(self) -> None: + """No error when messages list is empty.""" + validate_tool_use_result_messages([]) # Should not raise + + def test_no_error_for_simple_text_messages(self) -> None: + """No error for simple text messages.""" + messages = [ + SamplingMessage(role="user", content=TextContent(type="text", text="Hello")), + SamplingMessage(role="assistant", content=TextContent(type="text", text="Hi")), + ] + validate_tool_use_result_messages(messages) # Should not raise + + def test_raises_when_tool_result_mixed_with_other_content(self) -> None: + """Raises when tool_result is mixed with other content types.""" + messages = [ + SamplingMessage( + role="user", + content=[ + ToolResultContent(type="tool_result", toolUseId="123"), + TextContent(type="text", text="also this"), + ], + ), + ] + with pytest.raises(ValueError, match="only tool_result content"): + validate_tool_use_result_messages(messages) + + def test_raises_when_tool_result_without_previous_tool_use(self) -> None: + """Raises when tool_result appears without preceding tool_use.""" + messages = [ + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", toolUseId="123"), + ), + ] + with pytest.raises(ValueError, match="previous message containing tool_use"): + validate_tool_use_result_messages(messages) + + def test_raises_when_tool_result_ids_dont_match_tool_use(self) -> None: + """Raises when tool_result IDs don't match tool_use IDs.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", toolUseId="tool-2"), + ), + ] + with pytest.raises(ValueError, match="do not match"): + validate_tool_use_result_messages(messages) + + def test_no_error_when_tool_result_matches_tool_use(self) -> None: + """No error when tool_result IDs match tool_use IDs.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", toolUseId="tool-1"), + ), + ] + validate_tool_use_result_messages(messages) # Should not raise diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py new file mode 100644 index 0000000000..8845dfe781 --- /dev/null +++ b/tests/shared/test_exceptions.py @@ -0,0 +1,159 @@ +"""Tests for MCP exception classes.""" + +import pytest + +from mcp.shared.exceptions import McpError, UrlElicitationRequiredError +from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData + + +class TestUrlElicitationRequiredError: + """Tests for UrlElicitationRequiredError exception class.""" + + def test_create_with_single_elicitation(self) -> None: + """Test creating error with a single elicitation.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert error.error.code == URL_ELICITATION_REQUIRED + assert error.error.message == "URL elicitation required" + assert len(error.elicitations) == 1 + assert error.elicitations[0].elicitationId == "test-123" + + def test_create_with_multiple_elicitations(self) -> None: + """Test creating error with multiple elicitations uses plural message.""" + elicitations = [ + ElicitRequestURLParams( + mode="url", + message="Auth 1", + url="https://example.com/auth1", + elicitationId="test-1", + ), + ElicitRequestURLParams( + mode="url", + message="Auth 2", + url="https://example.com/auth2", + elicitationId="test-2", + ), + ] + error = UrlElicitationRequiredError(elicitations) + + assert error.error.message == "URL elicitations required" # Plural + assert len(error.elicitations) == 2 + + def test_custom_message(self) -> None: + """Test creating error with a custom message.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation], message="Custom message") + + assert error.error.message == "Custom message" + + def test_from_error_data(self) -> None: + """Test reconstructing error from ErrorData.""" + error_data = ErrorData( + code=URL_ELICITATION_REQUIRED, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Auth required", + "url": "https://example.com/auth", + "elicitationId": "test-123", + } + ] + }, + ) + + error = UrlElicitationRequiredError.from_error(error_data) + + assert len(error.elicitations) == 1 + assert error.elicitations[0].elicitationId == "test-123" + assert error.elicitations[0].url == "https://example.com/auth" + + def test_from_error_data_wrong_code(self) -> None: + """Test that from_error raises ValueError for wrong error code.""" + error_data = ErrorData( + code=-32600, # Wrong code + message="Some other error", + data={}, + ) + + with pytest.raises(ValueError, match="Expected error code"): + UrlElicitationRequiredError.from_error(error_data) + + def test_serialization_roundtrip(self) -> None: + """Test that error can be serialized and reconstructed.""" + original = UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + ] + ) + + # Simulate serialization over wire + error_data = original.error + + # Reconstruct + reconstructed = UrlElicitationRequiredError.from_error(error_data) + + assert reconstructed.elicitations[0].elicitationId == original.elicitations[0].elicitationId + assert reconstructed.elicitations[0].url == original.elicitations[0].url + assert reconstructed.elicitations[0].message == original.elicitations[0].message + + def test_error_data_contains_elicitations(self) -> None: + """Test that error data contains properly serialized elicitations.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Please authenticate", + url="https://example.com/oauth", + elicitationId="oauth-flow-1", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert error.error.data is not None + assert "elicitations" in error.error.data + elicit_data = error.error.data["elicitations"][0] + assert elicit_data["mode"] == "url" + assert elicit_data["message"] == "Please authenticate" + assert elicit_data["url"] == "https://example.com/oauth" + assert elicit_data["elicitationId"] == "oauth-flow-1" + + def test_inherits_from_mcp_error(self) -> None: + """Test that UrlElicitationRequiredError inherits from McpError.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert isinstance(error, McpError) + assert isinstance(error, Exception) + + def test_exception_message(self) -> None: + """Test that exception message is set correctly.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitationId="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + # The exception's string representation should match the message + assert str(error) == "URL elicitation required" diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index 16bd6cb930..ca4368e9f8 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -13,7 +13,7 @@ def mcp_server() -> Server: server = Server(name="test_server") @server.list_resources() - async def handle_list_resources(): + async def handle_list_resources(): # pragma: no cover return [ Resource( uri=AnyUrl("memory://test"), diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 600972272d..1552711d2e 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -41,7 +41,7 @@ async def run_server(): async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) - except Exception as e: + except Exception as e: # pragma: no cover raise e # Track progress updates @@ -91,10 +91,10 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ if arguments and "_meta" in arguments: progressToken = arguments["_meta"]["progressToken"] - if not progressToken: + if not progressToken: # pragma: no cover raise ValueError("Empty progress token received") - if progressToken != client_progress_token: + if progressToken != client_progress_token: # pragma: no cover raise ValueError("Server sending back incorrect progressToken") # Send progress notifications @@ -119,22 +119,22 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ message="Server progress 100%", ) - else: + else: # pragma: no cover raise ValueError("Progress token not sent.") return [types.TextContent(type="text", text="Tool executed successfully")] - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {name}") # pragma: no cover # Client message handler to store progress notifications async def handle_client_message( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message - if isinstance(message, types.ServerNotification): - if isinstance(message.root, types.ProgressNotification): + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.ProgressNotification): # pragma: no branch params = message.root.params client_progress_updates.append( { @@ -248,14 +248,14 @@ async def run_server(): async for message in server_session.incoming_messages: try: await server._handle_message(message, server_session, {}) - except Exception as e: + except Exception as e: # pragma: no cover raise e # Client message handler async def handle_client_message( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, Exception): + if isinstance(message, Exception): # pragma: no cover raise message # run client session @@ -335,7 +335,9 @@ def mock_log_error(msg: str, *args: Any) -> None: logged_errors.append(msg % args if args else msg) # Create a progress callback that raises an exception - async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None: + async def failing_progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: # pragma: no cover raise ValueError("Progress callback failed!") # Create a server with a tool that sends progress notifications @@ -352,7 +354,7 @@ async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent] message="Halfway done", ) return [types.TextContent(type="text", text="progress_result")] - raise ValueError(f"Unknown tool: {name}") + raise ValueError(f"Unknown tool: {name}") # pragma: no cover @server.list_tools() async def handle_list_tools() -> list[types.Tool]: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 320693786c..b355a4bf2d 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -9,12 +9,18 @@ from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session +from mcp.shared.message import SessionMessage from mcp.types import ( CancelledNotification, CancelledNotificationParams, ClientNotification, ClientRequest, EmptyResult, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, TextContent, ) @@ -66,8 +72,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ request_id = server.request_context.request_id ev_tool_called.set() await anyio.sleep(10) # Long enough to ensure we can cancel - return [] - raise ValueError(f"Unknown tool: {name}") + return [] # pragma: no cover + raise ValueError(f"Unknown tool: {name}") # pragma: no cover # Register the tool so it shows up in list_tools @server.list_tools() @@ -93,7 +99,7 @@ async def make_request(client_session: ClientSession): ), types.CallToolResult, ) - pytest.fail("Request should have been cancelled") + pytest.fail("Request should have been cancelled") # pragma: no cover except McpError as e: # Expected - request was cancelled assert "Request cancelled" in str(e) @@ -122,6 +128,169 @@ async def make_request(client_session: ClientSession): await ev_cancelled.wait() +@pytest.mark.anyio +async def test_response_id_type_mismatch_string_to_int(): + """ + Test that responses with string IDs are correctly matched to requests sent with + integer IDs. + + This handles the case where a server returns "id": "0" (string) but the client + sent "id": 0 (integer). Without ID type normalization, this would cause a timeout. + """ + ev_response_received = anyio.Event() + result_holder: list[types.EmptyResult] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive a request and respond with a string ID instead of integer.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + root = message.message.root + assert isinstance(root, JSONRPCRequest) + # Get the original request ID (which is an integer) + request_id = root.id + assert isinstance(request_id, int), f"Expected int, got {type(request_id)}" + + # Respond with the ID as a string (simulating a buggy server) + response = JSONRPCResponse( + jsonrpc="2.0", + id=str(request_id), # Convert to string to simulate mismatch + result={}, + ) + await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + + async def make_request(client_session: ClientSession): + nonlocal result_holder + # Send a ping request (uses integer ID internally) + result = await client_session.send_ping() + result_holder.append(result) + ev_response_received.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): + await ev_response_received.wait() + + assert len(result_holder) == 1 + assert isinstance(result_holder[0], EmptyResult) + + +@pytest.mark.anyio +async def test_error_response_id_type_mismatch_string_to_int(): + """ + Test that error responses with string IDs are correctly matched to requests + sent with integer IDs. + + This handles the case where a server returns an error with "id": "0" (string) + but the client sent "id": 0 (integer). + """ + ev_error_received = anyio.Event() + error_holder: list[McpError] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive a request and respond with an error using a string ID.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + root = message.message.root + assert isinstance(root, JSONRPCRequest) + request_id = root.id + assert isinstance(request_id, int) + + # Respond with an error, using the ID as a string + error_response = JSONRPCError( + jsonrpc="2.0", + id=str(request_id), # Convert to string to simulate mismatch + error=ErrorData(code=-32600, message="Test error"), + ) + await server_write.send(SessionMessage(message=JSONRPCMessage(error_response))) + + async def make_request(client_session: ClientSession): + nonlocal error_holder + try: + await client_session.send_ping() + pytest.fail("Expected McpError to be raised") # pragma: no cover + except McpError as e: + error_holder.append(e) + ev_error_received.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): + await ev_error_received.wait() + + assert len(error_holder) == 1 + assert "Test error" in str(error_holder[0]) + + +@pytest.mark.anyio +async def test_response_id_non_numeric_string_no_match(): + """ + Test that responses with non-numeric string IDs don't incorrectly match + integer request IDs. + + If a server returns "id": "abc" (non-numeric string), it should not match + a request sent with "id": 0 (integer). + """ + ev_timeout = anyio.Event() + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive a request and respond with a non-numeric string ID.""" + message = await server_read.receive() + assert isinstance(message, SessionMessage) + + # Respond with a non-numeric string ID (should not match) + response = JSONRPCResponse( + jsonrpc="2.0", + id="not_a_number", # Non-numeric string + result={}, + ) + await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + + async def make_request(client_session: ClientSession): + try: + # Use a short timeout since we expect this to fail + await client_session.send_request( + ClientRequest(types.PingRequest()), + types.EmptyResult, + request_read_timeout_seconds=0.5, + ) + pytest.fail("Expected timeout") # pragma: no cover + except McpError as e: + assert "Timed out" in str(e) + ev_timeout.set() + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(2): + await ev_timeout.wait() + + @pytest.mark.anyio async def test_connection_closed(): """ @@ -141,7 +310,7 @@ async def make_request(client_session: ClientSession): try: # any request will do await client_session.initialize() - pytest.fail("Request should have errored") + pytest.fail("Request should have errored") # pragma: no cover except McpError as e: # Expected - request errored assert "Connection closed" in str(e) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7b0d89cb42..7604450f81 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -4,11 +4,13 @@ import time from collections.abc import AsyncGenerator, Generator from typing import Any +from unittest.mock import AsyncMock, MagicMock, Mock, patch import anyio import httpx import pytest import uvicorn +from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from pydantic import AnyUrl from starlette.applications import Starlette @@ -16,9 +18,10 @@ from starlette.responses import Response from starlette.routing import Mount, Route +import mcp.client.sse import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.sse import sse_client +from mcp.client.sse import _extract_session_id_from_endpoint, sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings @@ -26,12 +29,16 @@ from mcp.types import ( EmptyResult, ErrorData, + Implementation, InitializeResult, + JSONRPCResponse, ReadResourceResult, + ServerCapabilities, TextContent, TextResourceContents, Tool, ) +from tests.test_helpers import wait_for_server SERVER_NAME = "test_server_for_SSE" @@ -49,7 +56,7 @@ def server_url(http://github.com/server_port:%20int) -> str: # Test server implementation -class ServerTest(Server): +class ServerTest(Server): # pragma: no cover def __init__(self): super().__init__(SERVER_NAME) @@ -80,7 +87,7 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] # Test fixtures -def make_server_app() -> Starlette: +def make_server_app() -> Starlette: # pragma: no cover """Create test Starlette app with SSE transport""" # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( @@ -104,7 +111,7 @@ async def handle_sse(request: Request) -> Response: return app -def run_server(server_port: int) -> None: +def run_server(server_port: int) -> None: # pragma: no cover app = make_server_app() server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") @@ -123,19 +130,8 @@ def server(server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 print("waiting for server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + wait_for_server(server_port) yield @@ -143,7 +139,7 @@ def server(server_port: int) -> Generator[None, None, None]: # Signal the server to stop proc.kill() proc.join(timeout=2) - if proc.is_alive(): + if proc.is_alive(): # pragma: no cover print("server process failed to terminate") @@ -166,7 +162,7 @@ async def connection_test() -> None: assert response.headers["content-type"] == "text/event-stream; charset=utf-8" line_number = 0 - async for line in response.aiter_lines(): + async for line in response.aiter_lines(): # pragma: no branch if line_number == 0: assert line == "event: endpoint" elif line_number == 1: @@ -194,6 +190,57 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non assert isinstance(ping_result, EmptyResult) +@pytest.mark.anyio +async def test_sse_client_on_session_created(server: None, server_url: str) -> None: + captured_session_id: str | None = None + + def on_session_created(session_id: str) -> None: + nonlocal captured_session_id + captured_session_id = session_id + + async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: + async with ClientSession(*streams) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + assert captured_session_id is not None + assert len(captured_session_id) > 0 + + +@pytest.mark.parametrize( + "endpoint_url,expected", + [ + ("/messages?sessionId=abc123", "abc123"), + ("/messages?session_id=def456", "def456"), + ("/messages?sessionId=abc&session_id=def", "abc"), + ("/messages?other=value", None), + ("/messages", None), + ("", None), + ], +) +def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: + assert _extract_session_id_from_endpoint(endpoint_url) == expected + + +@pytest.mark.anyio +async def test_sse_client_on_session_created_not_called_when_no_session_id( + server: None, server_url: str, monkeypatch: pytest.MonkeyPatch +) -> None: + callback_mock = Mock() + + def mock_extract(url: str) -> None: + return None + + monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) + + async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with ClientSession(*streams) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + callback_mock.assert_not_called() + + @pytest.fixture async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: @@ -224,7 +271,7 @@ async def test_sse_client_exception_handling( @pytest.mark.anyio @pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") -async def test_sse_client_timeout( +async def test_sse_client_timeout( # pragma: no cover initialized_sse_client_session: ClientSession, ) -> None: session = initialized_sse_client_session @@ -242,7 +289,7 @@ async def test_sse_client_timeout( pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: +def run_mounted_server(server_port: int) -> None: # pragma: no cover app = make_server_app() main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) @@ -262,19 +309,8 @@ def mounted_server(server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 print("waiting for server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + wait_for_server(server_port) yield @@ -282,7 +318,7 @@ def mounted_server(server_port: int) -> Generator[None, None, None]: # Signal the server to stop proc.kill() proc.join(timeout=2) - if proc.is_alive(): + if proc.is_alive(): # pragma: no cover print("server process failed to terminate") @@ -301,7 +337,7 @@ async def test_sse_client_basic_connection_mounted_app(mounted_server: None, ser # Test server with request context that returns headers in the response -class RequestContextServer(Server[object, Request]): +class RequestContextServer(Server[object, Request]): # pragma: no cover def __init__(self): super().__init__("request_context_server") @@ -343,7 +379,7 @@ async def handle_list_tools() -> list[Tool]: ] -def run_context_server(server_port: int) -> None: +def run_context_server(server_port: int) -> None: # pragma: no cover """Run a server that captures request context""" # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( @@ -377,26 +413,15 @@ def context_server(server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 print("waiting for context server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Context server failed to start after {max_attempts} attempts") + wait_for_server(server_port) yield print("killing context server") proc.kill() proc.join(timeout=2) - if proc.is_alive(): + if proc.is_alive(): # pragma: no cover print("context server process failed to terminate") @@ -511,3 +536,69 @@ def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result sse = SseServerTransport(endpoint) assert sse._endpoint == expected_result assert sse._endpoint.startswith("/") + + +# ResourceWarning filter: When mocking aconnect_sse, the sse_client's internal task +# group doesn't receive proper cancellation signals, so the sse_reader task's finally +# block (which closes read_stream_writer) doesn't execute. This is a test artifact - +# the actual code path (`if not sse.data: continue`) IS exercised and works correctly. +# Production code with real SSE connections cleans up properly. +@pytest.mark.filterwarnings("ignore::ResourceWarning") +@pytest.mark.anyio +async def test_sse_client_handles_empty_keepalive_pings() -> None: + """Test that SSE client properly handles empty data lines (keep-alive pings). + + Per the MCP spec (Streamable HTTP transport): "The server SHOULD immediately + send an SSE event consisting of an event ID and an empty data field in order + to prime the client to reconnect." + + This test mocks the SSE event stream to include empty "message" events and + verifies the client skips them without crashing. + """ + # Build a proper JSON-RPC response using types (not hardcoded strings) + init_result = InitializeResult( + protocolVersion="2024-11-05", + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="test", version="1.0"), + ) + response = JSONRPCResponse( + jsonrpc="2.0", + id=1, + result=init_result.model_dump(by_alias=True, exclude_none=True), + ) + response_json = response.model_dump_json(by_alias=True, exclude_none=True) + + # Create mock SSE events using httpx_sse's ServerSentEvent + async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: + # First: endpoint event + yield ServerSentEvent(event="endpoint", data="/messages/?session_id=abc123") + # Empty data keep-alive ping - this is what we're testing + yield ServerSentEvent(event="message", data="") + # Real JSON-RPC response + yield ServerSentEvent(event="message", data=response_json) + + mock_event_source = MagicMock() + mock_event_source.aiter_sse.return_value = mock_aiter_sse() + mock_event_source.response = MagicMock() + mock_event_source.response.raise_for_status = MagicMock() + + mock_aconnect_sse = MagicMock() + mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source) + mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock())) + + with ( + patch("mcp.client.sse.create_mcp_http_client", return_value=mock_client), + patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse), + ): + async with sse_client("http://test/sse") as (read_stream, _): + # Read the message - should skip the empty one and get the real response + msg = await read_stream.receive() + # If we get here without error, the empty message was skipped successfully + assert not isinstance(msg, Exception) + assert isinstance(msg.message.root, types.JSONRPCResponse) + assert msg.message.root.id == 1 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 55800da33e..e95c309fbc 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,12 +10,14 @@ import time from collections.abc import Generator from typing import Any +from unittest.mock import MagicMock import anyio import httpx import pytest import requests import uvicorn +from httpx_sse import ServerSentEvent from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request @@ -23,7 +25,11 @@ import mcp.types as types from mcp.client.session import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import ( + StreamableHTTPTransport, + streamable_http_client, + streamablehttp_client, # pyright: ignore[reportDeprecated] +) from mcp.server import Server from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -38,11 +44,20 @@ ) from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.message import ClientMessageMetadata +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool +from mcp.types import ( + InitializeResult, + JSONRPCMessage, + JSONRPCRequest, + TextContent, + TextResourceContents, + Tool, +) +from tests.test_helpers import wait_for_server # Test constants SERVER_NAME = "test_streamable_http_server" @@ -60,7 +75,7 @@ # Helper functions -def extract_protocol_version_from_sse(response: requests.Response) -> str: +def extract_protocol_version_from_sse(response: requests.Response) -> str: # pragma: no cover """Extract the negotiated protocol version from an SSE initialization response.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): @@ -75,17 +90,19 @@ class SimpleEventStore(EventStore): """Simple in-memory event store for testing.""" def __init__(self): - self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] + self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = [] self._event_id_counter = 0 - async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId: + async def store_event( # pragma: no cover + self, stream_id: StreamId, message: types.JSONRPCMessage | None + ) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( + async def replay_events_after( # pragma: no cover self, last_event_id: EventId, send_callback: EventCallback, @@ -108,13 +125,15 @@ async def replay_events_after( # Replay only events from the same stream with ID > last_event_id for stream_id, event_id, message in self._events: if stream_id == target_stream_id and int(event_id) > last_event_id_int: - await send_callback(EventMessage(message, event_id)) + # Skip priming events (None message) + if message is not None: + await send_callback(EventMessage(message, event_id)) return target_stream_id # Test server implementation that follows MCP protocol -class ServerTest(Server): +class ServerTest(Server): # pragma: no cover def __init__(self): super().__init__(SERVER_NAME) self._lock = None # Will be initialized in async context @@ -163,6 +182,32 @@ async def handle_list_tools() -> list[Tool]: description="A tool that releases the lock", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="tool_with_stream_close", + description="A tool that closes SSE stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_notifications_and_close", + description="Tool that sends notification1, closes stream, sends notification2, notification3", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="tool_with_multiple_stream_closes", + description="Tool that closes SSE stream multiple times during execution", + inputSchema={ + "type": "object", + "properties": { + "checkpoints": {"type": "integer", "default": 3}, + "sleep_time": {"type": "number", "default": 0.2}, + }, + }, + ), + Tool( + name="tool_with_standalone_stream_close", + description="Tool that closes standalone GET stream mid-operation", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() @@ -210,7 +255,11 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] ) # Return the sampling result in the tool response - response = sampling_result.content.text if sampling_result.content.type == "text" else None + # Since we're not passing tools param, result.content is single content + if sampling_result.content.type == "text": + response = sampling_result.content.text + else: + response = str(sampling_result.content) return [ TextContent( type="text", @@ -251,15 +300,107 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] self._lock.set() return [TextContent(type="text", text="Lock released")] + elif name == "tool_with_stream_close": + # Send notification before closing + await ctx.session.send_log_message( + level="info", + data="Before close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + # Close SSE stream (triggers client reconnect) + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Continue processing (events stored in event_store) + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="After close", + logger="stream_close_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="Done")] + + elif name == "tool_with_multiple_notifications_and_close": + # Send notification1 + await ctx.session.send_log_message( + level="info", + data="notification1", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + # Close SSE stream + assert ctx.close_sse_stream is not None + await ctx.close_sse_stream() + # Send notification2, notification3 (stored in event_store) + await anyio.sleep(0.1) + await ctx.session.send_log_message( + level="info", + data="notification2", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + await ctx.session.send_log_message( + level="info", + data="notification3", + logger="multi_notif_tool", + related_request_id=ctx.request_id, + ) + return [TextContent(type="text", text="All notifications sent")] + + elif name == "tool_with_multiple_stream_closes": + num_checkpoints = args.get("checkpoints", 3) + sleep_time = args.get("sleep_time", 0.2) + + for i in range(num_checkpoints): + await ctx.session.send_log_message( + level="info", + data=f"checkpoint_{i}", + logger="multi_close_tool", + related_request_id=ctx.request_id, + ) + + if ctx.close_sse_stream: + await ctx.close_sse_stream() + + await anyio.sleep(sleep_time) + + return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")] + + elif name == "tool_with_standalone_stream_close": + # Test for GET stream reconnection + # 1. Send unsolicited notification via GET stream (no related_request_id) + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_1/")) + + # Small delay to ensure notification is flushed before closing + await anyio.sleep(0.1) + + # 2. Close the standalone GET stream + if ctx.close_standalone_sse_stream: + await ctx.close_standalone_sse_stream() + + # 3. Wait for client to reconnect (uses retry_interval from server, default 1000ms) + await anyio.sleep(1.5) + + # 4. Send another notification on the new GET stream connection + await ctx.session.send_resource_updated(uri=AnyUrl("http://notification_2/")) + + return [TextContent(type="text", text="Standalone stream close test done")] + return [TextContent(type="text", text=f"Called {name}")] -def create_app(is_json_response_enabled: bool = False, event_store: EventStore | None = None) -> Starlette: +def create_app( + is_json_response_enabled: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, +) -> Starlette: # pragma: no cover """Create a Starlette application for testing using the session manager. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. + retry_interval: Retry interval in milliseconds for SSE polling. """ # Create server instance server = ServerTest() @@ -273,6 +414,7 @@ def create_app(is_json_response_enabled: bool = False, event_store: EventStore | event_store=event_store, json_response=is_json_response_enabled, security_settings=security_settings, + retry_interval=retry_interval, ) # Create an ASGI application that uses the session manager @@ -287,16 +429,22 @@ def create_app(is_json_response_enabled: bool = False, event_store: EventStore | return app -def run_server(port: int, is_json_response_enabled: bool = False, event_store: EventStore | None = None) -> None: +def run_server( + port: int, + is_json_response_enabled: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, +) -> None: # pragma: no cover """Run the test server. Args: port: Port to listen on. is_json_response_enabled: If True, use JSON responses instead of SSE streams. event_store: Optional event store for testing resumability. + retry_interval: Retry interval in milliseconds for SSE polling. """ - app = create_app(is_json_response_enabled, event_store) + app = create_app(is_json_response_enabled, event_store, retry_interval) # Configure server config = uvicorn.Config( app=app, @@ -344,18 +492,7 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", basic_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + wait_for_server(basic_server_port) yield @@ -382,27 +519,16 @@ def event_server_port() -> int: def event_server( event_server_port: int, event_store: SimpleEventStore ) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store enabled.""" + """Start a server with event store and retry_interval enabled.""" proc = multiprocessing.Process( target=run_server, - kwargs={"port": event_server_port, "event_store": event_store}, + kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, daemon=True, ) proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", event_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + wait_for_server(event_server_port) yield event_store, f"http://127.0.0.1:{event_server_port}" @@ -422,18 +548,7 @@ def json_response_server(json_server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", json_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + wait_for_server(json_server_port) yield @@ -693,6 +808,51 @@ def test_json_response(json_response_server: None, json_server_url: str): assert response.headers.get("Content-Type") == "application/json" +def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): + """Test that json_response servers only require application/json in Accept header.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + +def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): + """Test that json_response servers reject requests without Accept header.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): + """Test that json_response servers reject requests with incorrect Accept header.""" + mcp_url = f"{json_server_url}/mcp" + # Test with only text/event-stream (wrong for JSON server) + response = requests.post( + mcp_url, + headers={ + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session @@ -714,8 +874,8 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str): # Extract negotiated protocol version from SSE response init_data = None assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): - if line.startswith("data: "): + for line in init_response.text.splitlines(): # pragma: no branch + if line.startswith("data: "): # pragma: no cover init_data = json.loads(line[6:]) break assert init_data is not None @@ -774,8 +934,8 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Extract negotiated protocol version from SSE response init_data = None assert init_response.headers.get("Content-Type") == "text/event-stream" - for line in init_response.text.splitlines(): - if line.startswith("data: "): + for line in init_response.text.splitlines(): # pragma: no branch + if line.startswith("data: "): # pragma: no cover init_data = json.loads(line[6:]) break assert init_data is not None @@ -808,7 +968,7 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): +async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @@ -817,7 +977,7 @@ async def http_client(basic_server: None, basic_server_url: str): @pytest.fixture async def initialized_client_session(basic_server: None, basic_server_url: str): """Create initialized StreamableHTTP client session.""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -831,9 +991,9 @@ async def initialized_client_session(basic_server: None, basic_server_url: str): @pytest.mark.anyio -async def test_streamablehttp_client_basic_connection(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): """Test basic client connection with initialization.""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -849,7 +1009,7 @@ async def test_streamablehttp_client_basic_connection(basic_server: None, basic_ @pytest.mark.anyio -async def test_streamablehttp_client_resource_read(initialized_client_session: ClientSession): +async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession): """Test client resource read functionality.""" response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource")) assert len(response.contents) == 1 @@ -859,11 +1019,11 @@ async def test_streamablehttp_client_resource_read(initialized_client_session: C @pytest.mark.anyio -async def test_streamablehttp_client_tool_invocation(initialized_client_session: ClientSession): +async def test_streamable_http_client_tool_invocation(initialized_client_session: ClientSession): """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 10 assert tools.tools[0].name == "test_tool" # Call the tool @@ -874,7 +1034,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session: @pytest.mark.anyio -async def test_streamablehttp_client_error_handling(initialized_client_session: ClientSession): +async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession): """Test error handling in client.""" with pytest.raises(McpError) as exc_info: await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error")) @@ -883,9 +1043,9 @@ async def test_streamablehttp_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamablehttp_client_session_persistence(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): """Test that session ID persists across requests.""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -900,7 +1060,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 10 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -911,9 +1071,9 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas @pytest.mark.anyio -async def test_streamablehttp_client_json_response(json_response_server: None, json_server_url: str): +async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): """Test client with JSON response mode.""" - async with streamablehttp_client(f"{json_server_url}/mcp") as ( + async with streamable_http_client(f"{json_server_url}/mcp") as ( read_stream, write_stream, _, @@ -929,7 +1089,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 10 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -939,21 +1099,20 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j @pytest.mark.anyio -async def test_streamablehttp_client_get_stream(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" import mcp.types as types - from mcp.shared.session import RequestResponder notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications - async def message_handler( + async def message_handler( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, types.ServerNotification): + if isinstance(message, types.ServerNotification): # pragma: no branch notifications_received.append(message) - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -972,7 +1131,7 @@ async def message_handler( # Verify the notification is a ResourceUpdatedNotification resource_update_found = False for notif in notifications_received: - if isinstance(notif.root, types.ResourceUpdatedNotification): + if isinstance(notif.root, types.ResourceUpdatedNotification): # pragma: no branch assert str(notif.root.params.uri) == "http://test_resource/" resource_update_found = True @@ -980,13 +1139,13 @@ async def message_handler( @pytest.mark.anyio -async def test_streamablehttp_client_session_termination(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): """Test client session termination functionality.""" captured_session_id = None - # Create the streamablehttp_client with a custom httpx client to capture headers - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + # Create the streamable_http_client with a custom httpx client to capture headers + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, get_session_id, @@ -1000,28 +1159,29 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 10 - headers: dict[str, str] = {} - if captured_session_id: + headers: dict[str, str] = {} # pragma: no cover + if captured_session_id: # pragma: no cover headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( - read_stream, - write_stream, - _, - ): - async with ClientSession(read_stream, write_stream) as session: - # Attempt to make a request after termination - with pytest.raises( - McpError, - match="Session terminated", - ): - await session.list_tools() + async with create_mcp_http_client(headers=headers) as httpx_client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + # Attempt to make a request after termination + with pytest.raises( # pragma: no branch + McpError, + match="Session terminated", + ): + await session.list_tools() @pytest.mark.anyio -async def test_streamablehttp_client_session_termination_204( +async def test_streamable_http_client_session_termination_204( basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch ): """Test client session termination functionality with a 204 response. @@ -1051,8 +1211,8 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt captured_session_id = None - # Create the streamablehttp_client with a custom httpx client to capture headers - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + # Create the streamable_http_client with a custom httpx client to capture headers + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, get_session_id, @@ -1066,28 +1226,29 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 6 + assert len(tools.tools) == 10 - headers: dict[str, str] = {} - if captured_session_id: + headers: dict[str, str] = {} # pragma: no cover + if captured_session_id: # pragma: no cover headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( - read_stream, - write_stream, - _, - ): - async with ClientSession(read_stream, write_stream) as session: - # Attempt to make a request after termination - with pytest.raises( - McpError, - match="Session terminated", - ): - await session.list_tools() + async with create_mcp_http_client(headers=headers) as httpx_client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + # Attempt to make a request after termination + with pytest.raises( # pragma: no branch + McpError, + match="Session terminated", + ): + await session.list_tools() @pytest.mark.anyio -async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]): +async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]): """Test client session resumption using sync primitives for reliable coordination.""" _, server_url = event_server @@ -1098,13 +1259,13 @@ async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventS captured_protocol_version = None first_notification_received = False - async def message_handler( + async def message_handler( # pragma: no branch message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, types.ServerNotification): + if isinstance(message, types.ServerNotification): # pragma: no branch captured_notifications.append(message) # Look for our first notification - if isinstance(message.root, types.LoggingMessageNotification): + if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch if message.root.params.data == "First notification before lock": nonlocal first_notification_received first_notification_received = True @@ -1114,7 +1275,7 @@ async def on_resumption_token_update(token: str) -> None: captured_resumption_token = token # First, start the client session and begin the tool that waits on lock - async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + async with streamable_http_client(f"{server_url}/mcp", terminate_on_close=False) as ( read_stream, write_stream, get_session_id, @@ -1157,55 +1318,57 @@ async def run_tool(): tg.cancel_scope.cancel() # Verify we received exactly one notification - assert len(captured_notifications) == 1 - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) - assert captured_notifications[0].root.params.data == "First notification before lock" + assert len(captured_notifications) == 1 # pragma: no cover + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover + assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover # Clear notifications for the second phase - captured_notifications = [] + captured_notifications = [] # pragma: no cover # Now resume the session with the same mcp-session-id and protocol version - headers: dict[str, Any] = {} - if captured_session_id: + headers: dict[str, Any] = {} # pragma: no cover + if captured_session_id: # pragma: no cover headers[MCP_SESSION_ID_HEADER] = captured_session_id - if captured_protocol_version: + if captured_protocol_version: # pragma: no cover headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version - async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( - read_stream, - write_stream, - _, - ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="release_lock", arguments={}), - ) - ), - types.CallToolResult, - ) - metadata = ClientMessageMetadata( - resumption_token=captured_resumption_token, - ) - result = await session.send_request( - types.ClientRequest( - types.CallToolRequest( - params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), - ) - ), - types.CallToolResult, - metadata=metadata, - ) - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert result.content[0].text == "Completed" + async with create_mcp_http_client(headers=headers) as httpx_client: + async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams(name="release_lock", arguments={}), + ) + ), + types.CallToolResult, + ) + metadata = ClientMessageMetadata( + resumption_token=captured_resumption_token, + ) - # We should have received the remaining notifications - assert len(captured_notifications) == 1 + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Completed" - assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) - assert captured_notifications[0].root.params.data == "Second notification after lock" + # We should have received the remaining notifications + assert len(captured_notifications) == 1 + + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover + assert captured_notifications[0].root.params.data == "Second notification after lock" # pragma: no cover @pytest.mark.anyio @@ -1223,7 +1386,8 @@ async def sampling_callback( nonlocal sampling_callback_invoked, captured_message_params sampling_callback_invoked = True captured_message_params = params - message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None + msg_content = params.messages[0].content_as_list[0] + message_received = msg_content.text if msg_content.type == "text" else None return types.CreateMessageResult( role="assistant", @@ -1236,7 +1400,7 @@ async def sampling_callback( ) # Create client with sampling callback - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -1266,7 +1430,7 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -class ContextAwareServerTest(Server): +class ContextAwareServerTest(Server): # pragma: no cover def __init__(self): super().__init__("ContextAwareServer") @@ -1326,7 +1490,7 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] # Server runner for context-aware testing -def run_context_aware_server(port: int): +def run_context_aware_server(port: int): # pragma: no cover """Run the context-aware test server.""" server = ContextAwareServerTest() @@ -1362,24 +1526,13 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", basic_server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts") + wait_for_server(basic_server_port) yield proc.kill() proc.join(timeout=2) - if proc.is_alive(): + if proc.is_alive(): # pragma: no cover print("Context-aware server process failed to terminate") @@ -1392,28 +1545,29 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: "X-Trace-Id": "trace-123", } - async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as ( - read_stream, - write_stream, - _, - ): - async with ClientSession(read_stream, write_stream) as session: - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "ContextAwareServer" + async with create_mcp_http_client(headers=custom_headers) as httpx_client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "ContextAwareServer" - # Call the tool that echoes headers back - tool_result = await session.call_tool("echo_headers", {}) + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) + # Parse the JSON response + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) - # Verify headers were propagated - assert headers_data.get("authorization") == "Bearer test-token" - assert headers_data.get("x-custom-header") == "test-value" - assert headers_data.get("x-trace-id") == "trace-123" + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio @@ -1429,21 +1583,26 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No "Authorization": f"Bearer token-{i}", } - async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() + async with create_mcp_http_client(headers=headers) as httpx_client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() - # Call the tool that echoes context - tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) + # Call the tool that echoes context + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - context_data = json.loads(tool_result.content[0].text) - contexts.append(context_data) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + context_data = json.loads(tool_result.content[0].text) + contexts.append(context_data) # Verify each request had its own context - assert len(contexts) == 3 - for i, ctx in enumerate(contexts): + assert len(contexts) == 3 # pragma: no cover + for i, ctx in enumerate(contexts): # pragma: no cover assert ctx["request_id"] == f"request-{i}" assert ctx["headers"].get("x-request-id") == f"request-{i}" assert ctx["headers"].get("x-custom-value") == f"value-{i}" @@ -1453,7 +1612,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -1569,7 +1728,7 @@ async def test_client_crash_handled(basic_server: None, basic_server_url: str): # Simulate bad client that crashes after init async def bad_client(): """Client that triggers ClosedResourceError""" - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -1587,7 +1746,7 @@ async def bad_client(): await anyio.sleep(0.1) # Try a good client, it should still be able to connect and list tools - async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + async with streamable_http_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, _, @@ -1597,3 +1756,640 @@ async def bad_client(): assert isinstance(result, InitializeResult) tools = await session.list_tools() assert tools.tools + + +@pytest.mark.anyio +async def test_handle_sse_event_skips_empty_data(): + """Test that _handle_sse_event skips empty SSE data (keep-alive pings).""" + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") + + # Create a mock SSE event with empty data (keep-alive ping) + mock_sse = ServerSentEvent(event="message", data="", id=None, retry=None) + + # Create a mock stream writer + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + try: + # Call _handle_sse_event with empty data - should return False and not raise + result = await transport._handle_sse_event(mock_sse, write_stream) + + # Should return False (not complete) for empty data + assert result is False + + # Nothing should have been written to the stream + # Check buffer is empty (statistics().current_buffer_used returns buffer size) + assert write_stream.statistics().current_buffer_used == 0 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_priming_event_not_sent_for_old_protocol_version(): + """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" + # Create a transport with an event store + transport = StreamableHTTPServerTransport( + "/mcp", + event_store=SimpleEventStore(), + ) + + # Create a mock stream writer + write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) + + try: + # Call _maybe_send_priming_event with OLD protocol version - should NOT send + await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-06-18") + + # Nothing should have been written to the stream + assert write_stream.statistics().current_buffer_used == 0 + + # Now test with NEW protocol version - should send + await transport._maybe_send_priming_event("test-request-id-2", write_stream, "2025-11-25") + + # Should have written a priming event + assert write_stream.statistics().current_buffer_used == 1 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_priming_event_not_sent_without_event_store(): + """Test that _maybe_send_priming_event returns early when no event_store is configured.""" + # Create a transport WITHOUT an event store + transport = StreamableHTTPServerTransport("/mcp") + + # Create a mock stream writer + write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) + + try: + # Call _maybe_send_priming_event - should return early without sending + await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25") + + # Nothing should have been written to the stream + assert write_stream.statistics().current_buffer_used == 0 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_priming_event_includes_retry_interval(): + """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" + # Create a transport with an event store AND retry_interval + transport = StreamableHTTPServerTransport( + "/mcp", + event_store=SimpleEventStore(), + retry_interval=5000, + ) + + # Create a mock stream writer + write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) + + try: + # Call _maybe_send_priming_event with new protocol version + await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25") + + # Should have written a priming event with retry field + assert write_stream.statistics().current_buffer_used == 1 + + # Read the event and verify it has retry field + event = await read_stream.receive() + assert "retry" in event + assert event["retry"] == 5000 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): + """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" + # Create a transport with an event store + transport = StreamableHTTPServerTransport( + "/mcp", + event_store=SimpleEventStore(), + ) + + # Create a mock message and request + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list")) + mock_request = MagicMock() + + # Call _create_session_message with OLD protocol version + session_msg = transport._create_session_message(mock_message, mock_request, "test-request-id", "2025-06-18") + + # Callbacks should NOT be provided for old protocol version + assert session_msg.metadata is not None + assert isinstance(session_msg.metadata, ServerMessageMetadata) + assert session_msg.metadata.close_sse_stream is None + assert session_msg.metadata.close_standalone_sse_stream is None + + # Now test with NEW protocol version - should provide callbacks + session_msg_new = transport._create_session_message(mock_message, mock_request, "test-request-id-2", "2025-11-25") + + # Callbacks SHOULD be provided for new protocol version + assert session_msg_new.metadata is not None + assert isinstance(session_msg_new.metadata, ServerMessageMetadata) + assert session_msg_new.metadata.close_sse_stream is not None + assert session_msg_new.metadata.close_standalone_sse_stream is not None + + +@pytest.mark.anyio +async def test_streamable_http_client_receives_priming_event( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Client should receive priming event (resumption token update) on POST SSE stream.""" + _, server_url = event_server + + captured_resumption_tokens: list[str] = [] + + async def on_resumption_token_update(token: str) -> None: + captured_resumption_tokens.append(token) + + async with streamable_http_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call tool with resumption token callback via send_request + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + params=types.CallToolRequestParams(name="test_tool", arguments={}), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + assert result is not None + + # Should have received priming event token BEFORE response data + # Priming event = 1 token (empty data, id only) + # Response = 1 token (actual JSON-RPC response) + # Total = 2 tokens minimum + assert len(captured_resumption_tokens) >= 2, ( + f"Server must send priming event before response. " + f"Expected >= 2 tokens (priming + response), got {len(captured_resumption_tokens)}" + ) + assert captured_resumption_tokens[0] is not None + + +@pytest.mark.anyio +async def test_server_close_sse_stream_via_context( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Server tool can call ctx.close_sse_stream() to close connection.""" + _, server_url = event_server + + async with streamable_http_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call tool that closes stream mid-operation + # This should NOT raise NotImplementedError when fully implemented + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should still receive complete response (via auto-reconnect) + assert result is not None + assert len(result.content) > 0 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + +@pytest.mark.anyio +async def test_streamable_http_client_auto_reconnects( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Client should auto-reconnect with Last-Event-ID when server closes after priming event.""" + _, server_url = event_server + captured_notifications: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): # pragma: no branch + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch + captured_notifications.append(str(message.root.params.data)) + + async with streamable_http_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Call tool that: + # 1. Sends notification + # 2. Closes SSE stream + # 3. Sends more notifications (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_stream_close", {}) + + # Client should have auto-reconnected and received ALL notifications + assert len(captured_notifications) >= 2, ( + "Client should auto-reconnect and receive notifications sent both before and after stream close" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + +@pytest.mark.anyio +async def test_streamable_http_client_respects_retry_interval( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Client MUST respect retry field, waiting specified ms before reconnecting.""" + _, server_url = event_server + + async with streamable_http_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + start_time = time.monotonic() + result = await session.call_tool("tool_with_stream_close", {}) + elapsed = time.monotonic() - start_time + + # Verify result was received + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + # The elapsed time should include at least the retry interval + # if reconnection occurred. This test may be flaky depending on + # implementation details, but demonstrates the expected behavior. + # Note: This assertion may need adjustment based on actual implementation + assert elapsed >= 0.4, f"Client should wait ~500ms before reconnecting, but elapsed time was {elapsed:.3f}s" + + +@pytest.mark.anyio +async def test_streamable_http_sse_polling_full_cycle( + event_server: tuple[SimpleEventStore, str], +) -> None: + """End-to-end test: server closes stream, client reconnects, receives all events.""" + _, server_url = event_server + all_notifications: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): # pragma: no branch + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch + all_notifications.append(str(message.root.params.data)) + + async with streamable_http_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Call tool that simulates polling pattern: + # 1. Server sends priming event + # 2. Server sends "Before close" notification + # 3. Server closes stream (calls close_sse_stream) + # 4. (client reconnects automatically) + # 5. Server sends "After close" notification + # 6. Server sends final response + result = await session.call_tool("tool_with_stream_close", {}) + + # Verify all notifications received in order + assert "Before close" in all_notifications, "Should receive notification sent before stream close" + assert "After close" in all_notifications, ( + "Should receive notification sent after stream close (via auto-reconnect)" + ) + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Done" + + +@pytest.mark.anyio +async def test_streamable_http_events_replayed_after_disconnect( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Events sent while client is disconnected should be replayed on reconnect.""" + _, server_url = event_server + notification_data: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): # pragma: no branch + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch + notification_data.append(str(message.root.params.data)) + + async with streamable_http_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Tool sends: notification1, close_stream, notification2, notification3, response + # Client should receive all notifications even though 2&3 were sent during disconnect + result = await session.call_tool("tool_with_multiple_notifications_and_close", {}) + + assert "notification1" in notification_data, "Should receive notification1 (sent before close)" + assert "notification2" in notification_data, "Should receive notification2 (sent after close, replayed)" + assert "notification3" in notification_data, "Should receive notification3 (sent after close, replayed)" + + # Verify order: notification1 should come before notification2 and notification3 + idx1 = notification_data.index("notification1") + idx2 = notification_data.index("notification2") + idx3 = notification_data.index("notification3") + assert idx1 < idx2 < idx3, "Notifications should be received in order" + + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "All notifications sent" + + +@pytest.mark.anyio +async def test_streamable_http_multiple_reconnections( + event_server: tuple[SimpleEventStore, str], +): + """Verify multiple close_sse_stream() calls each trigger a client reconnect. + + Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure + client has time to reconnect before the next checkpoint. + + With 3 checkpoints, we expect 8 resumption tokens: + - 1 priming (initial POST connection) + - 3 notifications (checkpoint_0, checkpoint_1, checkpoint_2) + - 3 priming (one per reconnect after each close) + - 1 response + """ + _, server_url = event_server + resumption_tokens: list[str] = [] + + async def on_resumption_token(token: str) -> None: + resumption_tokens.append(token) + + async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Use send_request with metadata to track resumption tokens + metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="tool_with_multiple_stream_closes", + # retry_interval=500ms, so sleep 600ms to ensure reconnect completes + arguments={"checkpoints": 3, "sleep_time": 0.6}, + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert "Completed 3 checkpoints" in result.content[0].text + + # 4 priming + 3 notifications + 1 response = 8 tokens + assert len(resumption_tokens) == 8, ( # pragma: no cover + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) + + +@pytest.mark.anyio +async def test_standalone_get_stream_reconnection( + event_server: tuple[SimpleEventStore, str], +) -> None: + """ + Test that standalone GET stream automatically reconnects after server closes it. + + Verifies: + 1. Client receives notification 1 via GET stream + 2. Server closes GET stream + 3. Client reconnects with Last-Event-ID + 4. Client receives notification 2 on new connection + + Note: Requires event_server fixture (with event store) because close_standalone_sse_stream + callback is only provided when event_store is configured and protocol version >= 2025-11-25. + """ + _, server_url = event_server + received_notifications: list[str] = [] + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + return # pragma: no cover + if isinstance(message, types.ServerNotification): # pragma: no branch + if isinstance(message.root, types.ResourceUpdatedNotification): # pragma: no branch + received_notifications.append(str(message.root.params.uri)) + + async with streamable_http_client(f"{server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + ) as session: + await session.initialize() + + # Call tool that: + # 1. Sends notification_1 via GET stream + # 2. Closes standalone GET stream + # 3. Sends notification_2 (stored in event_store) + # 4. Returns response + result = await session.call_tool("tool_with_standalone_stream_close", {}) + + # Verify the tool completed + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Standalone stream close test done" + + # Verify both notifications were received + assert "http://notification_1/" in received_notifications, ( + f"Should receive notification 1 (sent before GET stream close), got: {received_notifications}" + ) + assert "http://notification_2/" in received_notifications, ( + f"Should receive notification 2 after reconnect, got: {received_notifications}" + ) + + +@pytest.mark.anyio +async def test_streamable_http_client_does_not_mutate_provided_client( + basic_server: None, basic_server_url: str +) -> None: + """Test that streamable_http_client does not mutate the provided httpx client's headers.""" + # Create a client with custom headers + original_headers = { + "X-Custom-Header": "custom-value", + "Authorization": "Bearer test-token", + } + + async with httpx.AsyncClient(headers=original_headers, follow_redirects=True) as custom_client: + # Use the client with streamable_http_client + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=custom_client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Verify client headers were not mutated with MCP protocol headers + # If accept header exists, it should still be httpx default, not MCP's + if "accept" in custom_client.headers: # pragma: no branch + assert custom_client.headers.get("accept") == "*/*" + # MCP content-type should not have been added + assert custom_client.headers.get("content-type") != "application/json" + + # Verify custom headers are still present and unchanged + assert custom_client.headers.get("X-Custom-Header") == "custom-value" + assert custom_client.headers.get("Authorization") == "Bearer test-token" + + +@pytest.mark.anyio +async def test_streamable_http_client_mcp_headers_override_defaults( + context_aware_server: None, basic_server_url: str +) -> None: + """Test that MCP protocol headers override httpx.AsyncClient default headers.""" + # httpx.AsyncClient has default "accept: */*" header + # We need to verify that our MCP accept header overrides it in actual requests + + async with httpx.AsyncClient(follow_redirects=True) as client: + # Verify client has default accept header + assert client.headers.get("accept") == "*/*" + + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + # Use echo_headers tool to see what headers the server actually received + tool_result = await session.call_tool("echo_headers", {}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify MCP protocol headers were sent (not httpx defaults) + assert "accept" in headers_data + assert "application/json" in headers_data["accept"] + assert "text/event-stream" in headers_data["accept"] + + assert "content-type" in headers_data + assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_streamable_http_client_preserves_custom_with_mcp_headers( + context_aware_server: None, basic_server_url: str +) -> None: + """Test that both custom headers and MCP protocol headers are sent in requests.""" + custom_headers = { + "X-Custom-Header": "custom-value", + "X-Request-Id": "req-123", + "Authorization": "Bearer test-token", + } + + async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + # Use echo_headers tool to verify both custom and MCP headers are present + tool_result = await session.call_tool("echo_headers", {}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify custom headers are present + assert headers_data.get("x-custom-header") == "custom-value" + assert headers_data.get("x-request-id") == "req-123" + assert headers_data.get("authorization") == "Bearer test-token" + + # Verify MCP protocol headers are also present + assert "accept" in headers_data + assert "application/json" in headers_data["accept"] + assert "text/event-stream" in headers_data["accept"] + + assert "content-type" in headers_data + assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_streamable_http_transport_deprecated_params_ignored(basic_server: None, basic_server_url: str) -> None: + """Test that deprecated parameters passed to StreamableHTTPTransport are properly ignored.""" + with pytest.warns(DeprecationWarning): + transport = StreamableHTTPTransport( # pyright: ignore[reportDeprecated] + url=f"{basic_server_url}/mcp", + headers={"X-Should-Be-Ignored": "ignored"}, + timeout=999.0, + sse_read_timeout=999.0, + auth=None, + ) + + headers = transport._prepare_headers() + assert "X-Should-Be-Ignored" not in headers + assert headers["accept"] == "application/json, text/event-stream" + assert headers["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_streamablehttp_client_deprecation_warning(basic_server: None, basic_server_url: str) -> None: + """Test that the old streamablehttp_client() function issues a deprecation warning.""" + with pytest.warns(DeprecationWarning, match="Use `streamable_http_client` instead"): + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( # pyright: ignore[reportDeprecated] + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + tools = await session.list_tools() + assert len(tools.tools) > 0 diff --git a/tests/shared/test_tool_name_validation.py b/tests/shared/test_tool_name_validation.py new file mode 100644 index 0000000000..4746f3f9f8 --- /dev/null +++ b/tests/shared/test_tool_name_validation.py @@ -0,0 +1,199 @@ +"""Tests for tool name validation utilities (SEP-986).""" + +import logging + +import pytest + +from mcp.shared.tool_name_validation import ( + issue_tool_name_warning, + validate_and_warn_tool_name, + validate_tool_name, +) + + +class TestValidateToolName: + """Tests for validate_tool_name function.""" + + class TestValidNames: + """Test cases for valid tool names.""" + + @pytest.mark.parametrize( + "tool_name", + [ + "getUser", + "get_user_profile", + "user-profile-update", + "admin.tools.list", + "DATA_EXPORT_v2.1", + "a", + "a" * 128, + ], + ids=[ + "simple_alphanumeric", + "with_underscores", + "with_dashes", + "with_dots", + "mixed_characters", + "single_character", + "max_length_128", + ], + ) + def test_accepts_valid_names(self, tool_name: str) -> None: + """Valid tool names should pass validation with no warnings.""" + result = validate_tool_name(tool_name) + assert result.is_valid is True + assert result.warnings == [] + + class TestInvalidNames: + """Test cases for invalid tool names.""" + + def test_rejects_empty_name(self) -> None: + """Empty names should be rejected.""" + result = validate_tool_name("") + assert result.is_valid is False + assert "Tool name cannot be empty" in result.warnings + + def test_rejects_name_exceeding_max_length(self) -> None: + """Names exceeding 128 characters should be rejected.""" + result = validate_tool_name("a" * 129) + assert result.is_valid is False + assert any("exceeds maximum length of 128 characters (current: 129)" in w for w in result.warnings) + + @pytest.mark.parametrize( + "tool_name,expected_char", + [ + ("get user profile", "' '"), + ("get,user,profile", "','"), + ("user/profile/update", "'/'"), + ("user@domain.com", "'@'"), + ], + ids=[ + "with_spaces", + "with_commas", + "with_slashes", + "with_at_symbol", + ], + ) + def test_rejects_invalid_characters(self, tool_name: str, expected_char: str) -> None: + """Names with invalid characters should be rejected.""" + result = validate_tool_name(tool_name) + assert result.is_valid is False + assert any("invalid characters" in w and expected_char in w for w in result.warnings) + + def test_rejects_multiple_invalid_chars(self) -> None: + """Names with multiple invalid chars should list all of them.""" + result = validate_tool_name("user name@domain,com") + assert result.is_valid is False + warning = next(w for w in result.warnings if "invalid characters" in w) + assert "' '" in warning + assert "'@'" in warning + assert "','" in warning + + def test_rejects_unicode_characters(self) -> None: + """Names with unicode characters should be rejected.""" + result = validate_tool_name("user-\u00f1ame") # n with tilde + assert result.is_valid is False + + class TestWarningsForProblematicPatterns: + """Test cases for valid names that generate warnings.""" + + def test_warns_on_leading_dash(self) -> None: + """Names starting with dash should generate warning but be valid.""" + result = validate_tool_name("-get-user") + assert result.is_valid is True + assert any("starts or ends with a dash" in w for w in result.warnings) + + def test_warns_on_trailing_dash(self) -> None: + """Names ending with dash should generate warning but be valid.""" + result = validate_tool_name("get-user-") + assert result.is_valid is True + assert any("starts or ends with a dash" in w for w in result.warnings) + + def test_warns_on_leading_dot(self) -> None: + """Names starting with dot should generate warning but be valid.""" + result = validate_tool_name(".get.user") + assert result.is_valid is True + assert any("starts or ends with a dot" in w for w in result.warnings) + + def test_warns_on_trailing_dot(self) -> None: + """Names ending with dot should generate warning but be valid.""" + result = validate_tool_name("get.user.") + assert result.is_valid is True + assert any("starts or ends with a dot" in w for w in result.warnings) + + +class TestIssueToolNameWarning: + """Tests for issue_tool_name_warning function.""" + + def test_logs_warnings(self, caplog: pytest.LogCaptureFixture) -> None: + """Warnings should be logged at WARNING level.""" + warnings = ["Warning 1", "Warning 2"] + with caplog.at_level(logging.WARNING): + issue_tool_name_warning("test-tool", warnings) + + assert 'Tool name validation warning for "test-tool"' in caplog.text + assert "- Warning 1" in caplog.text + assert "- Warning 2" in caplog.text + assert "Tool registration will proceed" in caplog.text + assert "SEP-986" in caplog.text + + def test_no_logging_for_empty_warnings(self, caplog: pytest.LogCaptureFixture) -> None: + """Empty warnings list should not produce any log output.""" + with caplog.at_level(logging.WARNING): + issue_tool_name_warning("test-tool", []) + + assert caplog.text == "" + + +class TestValidateAndWarnToolName: + """Tests for validate_and_warn_tool_name function.""" + + def test_returns_true_for_valid_name(self) -> None: + """Valid names should return True.""" + assert validate_and_warn_tool_name("valid-tool-name") is True + + def test_returns_false_for_invalid_name(self) -> None: + """Invalid names should return False.""" + assert validate_and_warn_tool_name("") is False + assert validate_and_warn_tool_name("a" * 129) is False + assert validate_and_warn_tool_name("invalid name") is False + + def test_logs_warnings_for_invalid_name(self, caplog: pytest.LogCaptureFixture) -> None: + """Invalid names should trigger warning logs.""" + with caplog.at_level(logging.WARNING): + validate_and_warn_tool_name("invalid name") + + assert "Tool name validation warning" in caplog.text + + def test_no_warnings_for_clean_valid_name(self, caplog: pytest.LogCaptureFixture) -> None: + """Clean valid names should not produce any log output.""" + with caplog.at_level(logging.WARNING): + result = validate_and_warn_tool_name("clean-tool-name") + + assert result is True + assert caplog.text == "" + + +class TestEdgeCases: + """Test edge cases and robustness.""" + + @pytest.mark.parametrize( + "tool_name,is_valid,expected_warning_fragment", + [ + ("...", True, "starts or ends with a dot"), + ("---", True, "starts or ends with a dash"), + ("///", False, "invalid characters"), + ("user@name123", False, "invalid characters"), + ], + ids=[ + "only_dots", + "only_dashes", + "only_slashes", + "mixed_valid_invalid", + ], + ) + def test_edge_cases(self, tool_name: str, is_valid: bool, expected_warning_fragment: str) -> None: + """Various edge cases should be handled correctly.""" + result = validate_tool_name(tool_name) + assert result.is_valid is is_valid + assert any(expected_warning_fragment in w for w in result.warnings) diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 2d67eccdd0..f093cb4927 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -26,6 +26,7 @@ TextResourceContents, Tool, ) +from tests.test_helpers import wait_for_server SERVER_NAME = "test_server_for_WS" @@ -43,7 +44,7 @@ def server_url(http://github.com/server_port:%20int) -> str: # Test server implementation -class ServerTest(Server): +class ServerTest(Server): # pragma: no cover def __init__(self): super().__init__(SERVER_NAME) @@ -74,7 +75,7 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] # Test fixtures -def make_server_app() -> Starlette: +def make_server_app() -> Starlette: # pragma: no cover """Create test Starlette app with WebSocket transport""" server = ServerTest() @@ -91,7 +92,7 @@ async def handle_ws(websocket: WebSocket): return app -def run_server(server_port: int) -> None: +def run_server(server_port: int) -> None: # pragma: no cover app = make_server_app() server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") @@ -103,26 +104,15 @@ def run_server(server_port: int) -> None: time.sleep(0.5) -@pytest.fixture() +@pytest.fixture() # pragma: no cover def server(server_port: int) -> Generator[None, None, None]: proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) print("starting process") proc.start() # Wait for server to be running - max_attempts = 20 - attempt = 0 print("waiting for server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + wait_for_server(server_port) yield @@ -130,7 +120,7 @@ def server(server_port: int) -> Generator[None, None, None]: # Signal the server to stop proc.kill() proc.join(timeout=2) - if proc.is_alive(): + if proc.is_alive(): # pragma: no cover print("server process failed to terminate") diff --git a/tests/test_examples.py b/tests/test_examples.py index 59063f122f..6f5464e394 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -44,6 +44,23 @@ async def test_complex_inputs(): assert result.content[2].text == "charlie" +@pytest.mark.anyio +async def test_direct_call_tool_result_return(): + """Test the CallToolResult echo server""" + from examples.fastmcp.direct_call_tool_result_return import mcp + + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("echo", {"text": "hello"}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "hello" + assert result.structuredContent + assert result.structuredContent["text"] == "hello" + assert isinstance(result.meta, dict) + assert result.meta["some"] == "metadata" + + @pytest.mark.anyio async def test_desktop(monkeypatch: pytest.MonkeyPatch): """Test the desktop server""" @@ -72,13 +89,13 @@ async def test_desktop(monkeypatch: pytest.MonkeyPatch): content = result.contents[0] assert isinstance(content, TextResourceContents) assert isinstance(content.text, str) - if sys.platform == "win32": + if sys.platform == "win32": # pragma: no cover file_1 = "/fake/path/file1.txt".replace("/", "\\\\") # might be a bug file_2 = "/fake/path/file2.txt".replace("/", "\\\\") # might be a bug assert file_1 in content.text assert file_2 in content.text # might be a bug, but the test is passing - else: + else: # pragma: no cover assert "/fake/path/file1.txt" in content.text assert "/fake/path/file2.txt" in content.text diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000000..5c04c269ff --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,31 @@ +"""Common test utilities for MCP server tests.""" + +import socket +import time + + +def wait_for_server(port: int, timeout: float = 20.0) -> None: + """Wait for server to be ready to accept connections. + + Polls the server port until it accepts connections or timeout is reached. + This eliminates race conditions without arbitrary sleeps. + + Args: + port: The port number to check + timeout: Maximum time to wait in seconds (default 5.0) + + Raises: + TimeoutError: If server doesn't start within the timeout period + """ + start_time = time.time() + while time.time() - start_time < timeout: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(0.1) + s.connect(("127.0.0.1", port)) + # Server is ready + return + except (ConnectionRefusedError, OSError): + # Server not ready yet, retry quickly + time.sleep(0.01) + raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover diff --git a/tests/test_types.py b/tests/test_types.py index 415eba66a7..1c16c3cc6e 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,14 +1,27 @@ +from typing import Any + import pytest from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, ClientRequest, + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, Implementation, InitializeRequest, InitializeRequestParams, JSONRPCMessage, JSONRPCRequest, + ListToolsResult, + SamplingCapability, + SamplingMessage, + TextContent, + Tool, + ToolChoice, + ToolResultContent, + ToolUseContent, ) @@ -56,3 +69,295 @@ async def test_method_initialization(): assert initialize_request.method == "initialize", "method should be set to 'initialize'" assert initialize_request.params is not None assert initialize_request.params.protocolVersion == LATEST_PROTOCOL_VERSION + + +@pytest.mark.anyio +async def test_tool_use_content(): + """Test ToolUseContent type for SEP-1577.""" + tool_use_data = { + "type": "tool_use", + "name": "get_weather", + "id": "call_abc123", + "input": {"location": "San Francisco", "unit": "celsius"}, + } + + tool_use = ToolUseContent.model_validate(tool_use_data) + assert tool_use.type == "tool_use" + assert tool_use.name == "get_weather" + assert tool_use.id == "call_abc123" + assert tool_use.input == {"location": "San Francisco", "unit": "celsius"} + + # Test serialization + serialized = tool_use.model_dump(by_alias=True, exclude_none=True) + assert serialized["type"] == "tool_use" + assert serialized["name"] == "get_weather" + + +@pytest.mark.anyio +async def test_tool_result_content(): + """Test ToolResultContent type for SEP-1577.""" + tool_result_data = { + "type": "tool_result", + "toolUseId": "call_abc123", + "content": [{"type": "text", "text": "It's 72°F in San Francisco"}], + "isError": False, + } + + tool_result = ToolResultContent.model_validate(tool_result_data) + assert tool_result.type == "tool_result" + assert tool_result.toolUseId == "call_abc123" + assert len(tool_result.content) == 1 + assert tool_result.isError is False + + # Test with empty content (should default to []) + minimal_result_data = {"type": "tool_result", "toolUseId": "call_xyz"} + minimal_result = ToolResultContent.model_validate(minimal_result_data) + assert minimal_result.content == [] + + +@pytest.mark.anyio +async def test_tool_choice(): + """Test ToolChoice type for SEP-1577.""" + # Test with mode + tool_choice_data = {"mode": "required"} + tool_choice = ToolChoice.model_validate(tool_choice_data) + assert tool_choice.mode == "required" + + # Test with minimal data (all fields optional) + minimal_choice = ToolChoice.model_validate({}) + assert minimal_choice.mode is None + + # Test different modes + auto_choice = ToolChoice.model_validate({"mode": "auto"}) + assert auto_choice.mode == "auto" + + none_choice = ToolChoice.model_validate({"mode": "none"}) + assert none_choice.mode == "none" + + +@pytest.mark.anyio +async def test_sampling_message_with_user_role(): + """Test SamplingMessage with user role for SEP-1577.""" + # Test with single content + user_msg_data = {"role": "user", "content": {"type": "text", "text": "Hello"}} + user_msg = SamplingMessage.model_validate(user_msg_data) + assert user_msg.role == "user" + assert isinstance(user_msg.content, TextContent) + + # Test with array of content including tool result + multi_content_data: dict[str, Any] = { + "role": "user", + "content": [ + {"type": "text", "text": "Here's the result:"}, + {"type": "tool_result", "toolUseId": "call_123", "content": []}, + ], + } + multi_msg = SamplingMessage.model_validate(multi_content_data) + assert multi_msg.role == "user" + assert isinstance(multi_msg.content, list) + assert len(multi_msg.content) == 2 + + +@pytest.mark.anyio +async def test_sampling_message_with_assistant_role(): + """Test SamplingMessage with assistant role for SEP-1577.""" + # Test with tool use content + assistant_msg_data = { + "role": "assistant", + "content": { + "type": "tool_use", + "name": "search", + "id": "call_456", + "input": {"query": "MCP protocol"}, + }, + } + assistant_msg = SamplingMessage.model_validate(assistant_msg_data) + assert assistant_msg.role == "assistant" + assert isinstance(assistant_msg.content, ToolUseContent) + + # Test with array of mixed content + multi_content_data: dict[str, Any] = { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me search for that..."}, + {"type": "tool_use", "name": "search", "id": "call_789", "input": {}}, + ], + } + multi_msg = SamplingMessage.model_validate(multi_content_data) + assert isinstance(multi_msg.content, list) + assert len(multi_msg.content) == 2 + + +@pytest.mark.anyio +async def test_sampling_message_backward_compatibility(): + """Test that SamplingMessage maintains backward compatibility.""" + # Old-style message (single content, no tools) + old_style_data = {"role": "user", "content": {"type": "text", "text": "Hello"}} + old_msg = SamplingMessage.model_validate(old_style_data) + assert old_msg.role == "user" + assert isinstance(old_msg.content, TextContent) + + # New-style message with tool content + new_style_data: dict[str, Any] = { + "role": "assistant", + "content": {"type": "tool_use", "name": "test", "id": "call_1", "input": {}}, + } + new_msg = SamplingMessage.model_validate(new_style_data) + assert new_msg.role == "assistant" + assert isinstance(new_msg.content, ToolUseContent) + + # Array content + array_style_data: dict[str, Any] = { + "role": "user", + "content": [{"type": "text", "text": "Result:"}, {"type": "tool_result", "toolUseId": "call_1", "content": []}], + } + array_msg = SamplingMessage.model_validate(array_style_data) + assert isinstance(array_msg.content, list) + + +@pytest.mark.anyio +async def test_create_message_request_params_with_tools(): + """Test CreateMessageRequestParams with tools for SEP-1577.""" + tool = Tool( + name="get_weather", + description="Get weather information", + inputSchema={"type": "object", "properties": {"location": {"type": "string"}}}, + ) + + params = CreateMessageRequestParams( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text="What's the weather?"))], + maxTokens=1000, + tools=[tool], + toolChoice=ToolChoice(mode="auto"), + ) + + assert params.tools is not None + assert len(params.tools) == 1 + assert params.tools[0].name == "get_weather" + assert params.toolChoice is not None + assert params.toolChoice.mode == "auto" + + +@pytest.mark.anyio +async def test_create_message_result_with_tool_use(): + """Test CreateMessageResultWithTools with tool use content for SEP-1577.""" + result_data = { + "role": "assistant", + "content": {"type": "tool_use", "name": "search", "id": "call_123", "input": {"query": "test"}}, + "model": "claude-3", + "stopReason": "toolUse", + } + + # Tool use content uses CreateMessageResultWithTools + result = CreateMessageResultWithTools.model_validate(result_data) + assert result.role == "assistant" + assert isinstance(result.content, ToolUseContent) + assert result.stopReason == "toolUse" + assert result.model == "claude-3" + + # Test content_as_list with single content (covers else branch) + content_list = result.content_as_list + assert len(content_list) == 1 + assert content_list[0] == result.content + + +@pytest.mark.anyio +async def test_create_message_result_basic(): + """Test CreateMessageResult with basic text content (backwards compatible).""" + result_data = { + "role": "assistant", + "content": {"type": "text", "text": "Hello!"}, + "model": "claude-3", + "stopReason": "endTurn", + } + + # Basic content uses CreateMessageResult (single content, no arrays) + result = CreateMessageResult.model_validate(result_data) + assert result.role == "assistant" + assert isinstance(result.content, TextContent) + assert result.content.text == "Hello!" + assert result.stopReason == "endTurn" + assert result.model == "claude-3" + + +@pytest.mark.anyio +async def test_client_capabilities_with_sampling_tools(): + """Test ClientCapabilities with nested sampling capabilities for SEP-1577.""" + # New structured format + capabilities_data: dict[str, Any] = { + "sampling": {"tools": {}}, + } + capabilities = ClientCapabilities.model_validate(capabilities_data) + assert capabilities.sampling is not None + assert isinstance(capabilities.sampling, SamplingCapability) + assert capabilities.sampling.tools is not None + + # With both context and tools + full_capabilities_data: dict[str, Any] = {"sampling": {"context": {}, "tools": {}}} + full_caps = ClientCapabilities.model_validate(full_capabilities_data) + assert isinstance(full_caps.sampling, SamplingCapability) + assert full_caps.sampling.context is not None + assert full_caps.sampling.tools is not None + + +def test_tool_preserves_json_schema_2020_12_fields(): + """Verify that JSON Schema 2020-12 keywords are preserved in Tool.inputSchema. + + SEP-1613 establishes JSON Schema 2020-12 as the default dialect for MCP. + This test ensures the SDK doesn't strip $schema, $defs, or additionalProperties. + """ + input_schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + } + }, + "properties": { + "name": {"type": "string"}, + "address": {"$ref": "#/$defs/address"}, + }, + "additionalProperties": False, + } + + tool = Tool(name="test_tool", description="A test tool", inputSchema=input_schema) + + # Verify fields are preserved in the model + assert tool.inputSchema["$schema"] == "https://json-schema.org/draft/2020-12/schema" + assert "$defs" in tool.inputSchema + assert "address" in tool.inputSchema["$defs"] + assert tool.inputSchema["additionalProperties"] is False + + # Verify fields survive serialization round-trip + serialized = tool.model_dump(mode="json", by_alias=True) + assert serialized["inputSchema"]["$schema"] == "https://json-schema.org/draft/2020-12/schema" + assert "$defs" in serialized["inputSchema"] + assert serialized["inputSchema"]["additionalProperties"] is False + + +def test_list_tools_result_preserves_json_schema_2020_12_fields(): + """Verify JSON Schema 2020-12 fields survive ListToolsResult deserialization.""" + raw_response = { + "tools": [ + { + "name": "json_schema_tool", + "description": "Tool with JSON Schema 2020-12 features", + "inputSchema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": {"item": {"type": "string"}}, + "properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/item"}}}, + "additionalProperties": False, + }, + } + ] + } + + result = ListToolsResult.model_validate(raw_response) + tool = result.tools[0] + + assert tool.inputSchema["$schema"] == "https://json-schema.org/draft/2020-12/schema" + assert "$defs" in tool.inputSchema + assert tool.inputSchema["additionalProperties"] is False diff --git a/uv.lock b/uv.lock index 68abdcc4f5..757709acdf 100644 --- a/uv.lock +++ b/uv.lock @@ -1,18 +1,29 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" [manifest] members = [ "mcp", + "mcp-conformance-auth-client", + "mcp-everything-server", "mcp-simple-auth", + "mcp-simple-auth-client", + "mcp-simple-chatbot", "mcp-simple-pagination", "mcp-simple-prompt", "mcp-simple-resource", "mcp-simple-streamablehttp", "mcp-simple-streamablehttp-stateless", + "mcp-simple-task", + "mcp-simple-task-client", + "mcp-simple-task-interactive", + "mcp-simple-task-interactive-client", "mcp-simple-tool", "mcp-snippets", + "mcp-sse-polling-client", + "mcp-sse-polling-demo", + "mcp-structured-output-lowlevel", ] [[package]] @@ -318,6 +329,157 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "coverage" +version = "7.10.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/26/d22c300112504f5f9a9fd2297ce33c35f3d353e4aeb987c8419453b2a7c2/coverage-7.10.7.tar.gz", hash = "sha256:f4ab143ab113be368a3e9b795f9cd7906c5ef407d6173fe9675a902e1fffc239", size = 827704, upload-time = "2025-09-21T20:03:56.815Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/6c/3a3f7a46888e69d18abe3ccc6fe4cb16cccb1e6a2f99698931dafca489e6/coverage-7.10.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:fc04cc7a3db33664e0c2d10eb8990ff6b3536f6842c9590ae8da4c614b9ed05a", size = 217987, upload-time = "2025-09-21T20:00:57.218Z" }, + { url = "https://files.pythonhosted.org/packages/03/94/952d30f180b1a916c11a56f5c22d3535e943aa22430e9e3322447e520e1c/coverage-7.10.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e201e015644e207139f7e2351980feb7040e6f4b2c2978892f3e3789d1c125e5", size = 218388, upload-time = "2025-09-21T20:01:00.081Z" }, + { url = "https://files.pythonhosted.org/packages/50/2b/9e0cf8ded1e114bcd8b2fd42792b57f1c4e9e4ea1824cde2af93a67305be/coverage-7.10.7-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:240af60539987ced2c399809bd34f7c78e8abe0736af91c3d7d0e795df633d17", size = 245148, upload-time = "2025-09-21T20:01:01.768Z" }, + { url = "https://files.pythonhosted.org/packages/19/20/d0384ac06a6f908783d9b6aa6135e41b093971499ec488e47279f5b846e6/coverage-7.10.7-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8421e088bc051361b01c4b3a50fd39a4b9133079a2229978d9d30511fd05231b", size = 246958, upload-time = "2025-09-21T20:01:03.355Z" }, + { url = "https://files.pythonhosted.org/packages/60/83/5c283cff3d41285f8eab897651585db908a909c572bdc014bcfaf8a8b6ae/coverage-7.10.7-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6be8ed3039ae7f7ac5ce058c308484787c86e8437e72b30bf5e88b8ea10f3c87", size = 248819, upload-time = "2025-09-21T20:01:04.968Z" }, + { url = "https://files.pythonhosted.org/packages/60/22/02eb98fdc5ff79f423e990d877693e5310ae1eab6cb20ae0b0b9ac45b23b/coverage-7.10.7-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e28299d9f2e889e6d51b1f043f58d5f997c373cc12e6403b90df95b8b047c13e", size = 245754, upload-time = "2025-09-21T20:01:06.321Z" }, + { url = "https://files.pythonhosted.org/packages/b4/bc/25c83bcf3ad141b32cd7dc45485ef3c01a776ca3aa8ef0a93e77e8b5bc43/coverage-7.10.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c4e16bd7761c5e454f4efd36f345286d6f7c5fa111623c355691e2755cae3b9e", size = 246860, upload-time = "2025-09-21T20:01:07.605Z" }, + { url = "https://files.pythonhosted.org/packages/3c/b7/95574702888b58c0928a6e982038c596f9c34d52c5e5107f1eef729399b5/coverage-7.10.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b1c81d0e5e160651879755c9c675b974276f135558cf4ba79fee7b8413a515df", size = 244877, upload-time = "2025-09-21T20:01:08.829Z" }, + { url = "https://files.pythonhosted.org/packages/47/b6/40095c185f235e085df0e0b158f6bd68cc6e1d80ba6c7721dc81d97ec318/coverage-7.10.7-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:606cc265adc9aaedcc84f1f064f0e8736bc45814f15a357e30fca7ecc01504e0", size = 245108, upload-time = "2025-09-21T20:01:10.527Z" }, + { url = "https://files.pythonhosted.org/packages/c8/50/4aea0556da7a4b93ec9168420d170b55e2eb50ae21b25062513d020c6861/coverage-7.10.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:10b24412692df990dbc34f8fb1b6b13d236ace9dfdd68df5b28c2e39cafbba13", size = 245752, upload-time = "2025-09-21T20:01:11.857Z" }, + { url = "https://files.pythonhosted.org/packages/6a/28/ea1a84a60828177ae3b100cb6723838523369a44ec5742313ed7db3da160/coverage-7.10.7-cp310-cp310-win32.whl", hash = "sha256:b51dcd060f18c19290d9b8a9dd1e0181538df2ce0717f562fff6cf74d9fc0b5b", size = 220497, upload-time = "2025-09-21T20:01:13.459Z" }, + { url = "https://files.pythonhosted.org/packages/fc/1a/a81d46bbeb3c3fd97b9602ebaa411e076219a150489bcc2c025f151bd52d/coverage-7.10.7-cp310-cp310-win_amd64.whl", hash = "sha256:3a622ac801b17198020f09af3eaf45666b344a0d69fc2a6ffe2ea83aeef1d807", size = 221392, upload-time = "2025-09-21T20:01:14.722Z" }, + { url = "https://files.pythonhosted.org/packages/d2/5d/c1a17867b0456f2e9ce2d8d4708a4c3a089947d0bec9c66cdf60c9e7739f/coverage-7.10.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a609f9c93113be646f44c2a0256d6ea375ad047005d7f57a5c15f614dc1b2f59", size = 218102, upload-time = "2025-09-21T20:01:16.089Z" }, + { url = "https://files.pythonhosted.org/packages/54/f0/514dcf4b4e3698b9a9077f084429681bf3aad2b4a72578f89d7f643eb506/coverage-7.10.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:65646bb0359386e07639c367a22cf9b5bf6304e8630b565d0626e2bdf329227a", size = 218505, upload-time = "2025-09-21T20:01:17.788Z" }, + { url = "https://files.pythonhosted.org/packages/20/f6/9626b81d17e2a4b25c63ac1b425ff307ecdeef03d67c9a147673ae40dc36/coverage-7.10.7-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5f33166f0dfcce728191f520bd2692914ec70fac2713f6bf3ce59c3deacb4699", size = 248898, upload-time = "2025-09-21T20:01:19.488Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ef/bd8e719c2f7417ba03239052e099b76ea1130ac0cbb183ee1fcaa58aaff3/coverage-7.10.7-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:35f5e3f9e455bb17831876048355dca0f758b6df22f49258cb5a91da23ef437d", size = 250831, upload-time = "2025-09-21T20:01:20.817Z" }, + { url = "https://files.pythonhosted.org/packages/a5/b6/bf054de41ec948b151ae2b79a55c107f5760979538f5fb80c195f2517718/coverage-7.10.7-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4da86b6d62a496e908ac2898243920c7992499c1712ff7c2b6d837cc69d9467e", size = 252937, upload-time = "2025-09-21T20:01:22.171Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e5/3860756aa6f9318227443c6ce4ed7bf9e70bb7f1447a0353f45ac5c7974b/coverage-7.10.7-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6b8b09c1fad947c84bbbc95eca841350fad9cbfa5a2d7ca88ac9f8d836c92e23", size = 249021, upload-time = "2025-09-21T20:01:23.907Z" }, + { url = "https://files.pythonhosted.org/packages/26/0f/bd08bd042854f7fd07b45808927ebcce99a7ed0f2f412d11629883517ac2/coverage-7.10.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4376538f36b533b46f8971d3a3e63464f2c7905c9800db97361c43a2b14792ab", size = 250626, upload-time = "2025-09-21T20:01:25.721Z" }, + { url = "https://files.pythonhosted.org/packages/8e/a7/4777b14de4abcc2e80c6b1d430f5d51eb18ed1d75fca56cbce5f2db9b36e/coverage-7.10.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:121da30abb574f6ce6ae09840dae322bef734480ceafe410117627aa54f76d82", size = 248682, upload-time = "2025-09-21T20:01:27.105Z" }, + { url = "https://files.pythonhosted.org/packages/34/72/17d082b00b53cd45679bad682fac058b87f011fd8b9fe31d77f5f8d3a4e4/coverage-7.10.7-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:88127d40df529336a9836870436fc2751c339fbaed3a836d42c93f3e4bd1d0a2", size = 248402, upload-time = "2025-09-21T20:01:28.629Z" }, + { url = "https://files.pythonhosted.org/packages/81/7a/92367572eb5bdd6a84bfa278cc7e97db192f9f45b28c94a9ca1a921c3577/coverage-7.10.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ba58bbcd1b72f136080c0bccc2400d66cc6115f3f906c499013d065ac33a4b61", size = 249320, upload-time = "2025-09-21T20:01:30.004Z" }, + { url = "https://files.pythonhosted.org/packages/2f/88/a23cc185f6a805dfc4fdf14a94016835eeb85e22ac3a0e66d5e89acd6462/coverage-7.10.7-cp311-cp311-win32.whl", hash = "sha256:972b9e3a4094b053a4e46832b4bc829fc8a8d347160eb39d03f1690316a99c14", size = 220536, upload-time = "2025-09-21T20:01:32.184Z" }, + { url = "https://files.pythonhosted.org/packages/fe/ef/0b510a399dfca17cec7bc2f05ad8bd78cf55f15c8bc9a73ab20c5c913c2e/coverage-7.10.7-cp311-cp311-win_amd64.whl", hash = "sha256:a7b55a944a7f43892e28ad4bc0561dfd5f0d73e605d1aa5c3c976b52aea121d2", size = 221425, upload-time = "2025-09-21T20:01:33.557Z" }, + { url = "https://files.pythonhosted.org/packages/51/7f/023657f301a276e4ba1850f82749bc136f5a7e8768060c2e5d9744a22951/coverage-7.10.7-cp311-cp311-win_arm64.whl", hash = "sha256:736f227fb490f03c6488f9b6d45855f8e0fd749c007f9303ad30efab0e73c05a", size = 220103, upload-time = "2025-09-21T20:01:34.929Z" }, + { url = "https://files.pythonhosted.org/packages/13/e4/eb12450f71b542a53972d19117ea5a5cea1cab3ac9e31b0b5d498df1bd5a/coverage-7.10.7-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7bb3b9ddb87ef7725056572368040c32775036472d5a033679d1fa6c8dc08417", size = 218290, upload-time = "2025-09-21T20:01:36.455Z" }, + { url = "https://files.pythonhosted.org/packages/37/66/593f9be12fc19fb36711f19a5371af79a718537204d16ea1d36f16bd78d2/coverage-7.10.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:18afb24843cbc175687225cab1138c95d262337f5473512010e46831aa0c2973", size = 218515, upload-time = "2025-09-21T20:01:37.982Z" }, + { url = "https://files.pythonhosted.org/packages/66/80/4c49f7ae09cafdacc73fbc30949ffe77359635c168f4e9ff33c9ebb07838/coverage-7.10.7-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:399a0b6347bcd3822be369392932884b8216d0944049ae22925631a9b3d4ba4c", size = 250020, upload-time = "2025-09-21T20:01:39.617Z" }, + { url = "https://files.pythonhosted.org/packages/a6/90/a64aaacab3b37a17aaedd83e8000142561a29eb262cede42d94a67f7556b/coverage-7.10.7-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:314f2c326ded3f4b09be11bc282eb2fc861184bc95748ae67b360ac962770be7", size = 252769, upload-time = "2025-09-21T20:01:41.341Z" }, + { url = "https://files.pythonhosted.org/packages/98/2e/2dda59afd6103b342e096f246ebc5f87a3363b5412609946c120f4e7750d/coverage-7.10.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c41e71c9cfb854789dee6fc51e46743a6d138b1803fab6cb860af43265b42ea6", size = 253901, upload-time = "2025-09-21T20:01:43.042Z" }, + { url = "https://files.pythonhosted.org/packages/53/dc/8d8119c9051d50f3119bb4a75f29f1e4a6ab9415cd1fa8bf22fcc3fb3b5f/coverage-7.10.7-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc01f57ca26269c2c706e838f6422e2a8788e41b3e3c65e2f41148212e57cd59", size = 250413, upload-time = "2025-09-21T20:01:44.469Z" }, + { url = "https://files.pythonhosted.org/packages/98/b3/edaff9c5d79ee4d4b6d3fe046f2b1d799850425695b789d491a64225d493/coverage-7.10.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a6442c59a8ac8b85812ce33bc4d05bde3fb22321fa8294e2a5b487c3505f611b", size = 251820, upload-time = "2025-09-21T20:01:45.915Z" }, + { url = "https://files.pythonhosted.org/packages/11/25/9a0728564bb05863f7e513e5a594fe5ffef091b325437f5430e8cfb0d530/coverage-7.10.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:78a384e49f46b80fb4c901d52d92abe098e78768ed829c673fbb53c498bef73a", size = 249941, upload-time = "2025-09-21T20:01:47.296Z" }, + { url = "https://files.pythonhosted.org/packages/e0/fd/ca2650443bfbef5b0e74373aac4df67b08180d2f184b482c41499668e258/coverage-7.10.7-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:5e1e9802121405ede4b0133aa4340ad8186a1d2526de5b7c3eca519db7bb89fb", size = 249519, upload-time = "2025-09-21T20:01:48.73Z" }, + { url = "https://files.pythonhosted.org/packages/24/79/f692f125fb4299b6f963b0745124998ebb8e73ecdfce4ceceb06a8c6bec5/coverage-7.10.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d41213ea25a86f69efd1575073d34ea11aabe075604ddf3d148ecfec9e1e96a1", size = 251375, upload-time = "2025-09-21T20:01:50.529Z" }, + { url = "https://files.pythonhosted.org/packages/5e/75/61b9bbd6c7d24d896bfeec57acba78e0f8deac68e6baf2d4804f7aae1f88/coverage-7.10.7-cp312-cp312-win32.whl", hash = "sha256:77eb4c747061a6af8d0f7bdb31f1e108d172762ef579166ec84542f711d90256", size = 220699, upload-time = "2025-09-21T20:01:51.941Z" }, + { url = "https://files.pythonhosted.org/packages/ca/f3/3bf7905288b45b075918d372498f1cf845b5b579b723c8fd17168018d5f5/coverage-7.10.7-cp312-cp312-win_amd64.whl", hash = "sha256:f51328ffe987aecf6d09f3cd9d979face89a617eacdaea43e7b3080777f647ba", size = 221512, upload-time = "2025-09-21T20:01:53.481Z" }, + { url = "https://files.pythonhosted.org/packages/5c/44/3e32dbe933979d05cf2dac5e697c8599cfe038aaf51223ab901e208d5a62/coverage-7.10.7-cp312-cp312-win_arm64.whl", hash = "sha256:bda5e34f8a75721c96085903c6f2197dc398c20ffd98df33f866a9c8fd95f4bf", size = 220147, upload-time = "2025-09-21T20:01:55.2Z" }, + { url = "https://files.pythonhosted.org/packages/9a/94/b765c1abcb613d103b64fcf10395f54d69b0ef8be6a0dd9c524384892cc7/coverage-7.10.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:981a651f543f2854abd3b5fcb3263aac581b18209be49863ba575de6edf4c14d", size = 218320, upload-time = "2025-09-21T20:01:56.629Z" }, + { url = "https://files.pythonhosted.org/packages/72/4f/732fff31c119bb73b35236dd333030f32c4bfe909f445b423e6c7594f9a2/coverage-7.10.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73ab1601f84dc804f7812dc297e93cd99381162da39c47040a827d4e8dafe63b", size = 218575, upload-time = "2025-09-21T20:01:58.203Z" }, + { url = "https://files.pythonhosted.org/packages/87/02/ae7e0af4b674be47566707777db1aa375474f02a1d64b9323e5813a6cdd5/coverage-7.10.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a8b6f03672aa6734e700bbcd65ff050fd19cddfec4b031cc8cf1c6967de5a68e", size = 249568, upload-time = "2025-09-21T20:01:59.748Z" }, + { url = "https://files.pythonhosted.org/packages/a2/77/8c6d22bf61921a59bce5471c2f1f7ac30cd4ac50aadde72b8c48d5727902/coverage-7.10.7-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10b6ba00ab1132a0ce4428ff68cf50a25efd6840a42cdf4239c9b99aad83be8b", size = 252174, upload-time = "2025-09-21T20:02:01.192Z" }, + { url = "https://files.pythonhosted.org/packages/b1/20/b6ea4f69bbb52dac0aebd62157ba6a9dddbfe664f5af8122dac296c3ee15/coverage-7.10.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c79124f70465a150e89340de5963f936ee97097d2ef76c869708c4248c63ca49", size = 253447, upload-time = "2025-09-21T20:02:02.701Z" }, + { url = "https://files.pythonhosted.org/packages/f9/28/4831523ba483a7f90f7b259d2018fef02cb4d5b90bc7c1505d6e5a84883c/coverage-7.10.7-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:69212fbccdbd5b0e39eac4067e20a4a5256609e209547d86f740d68ad4f04911", size = 249779, upload-time = "2025-09-21T20:02:04.185Z" }, + { url = "https://files.pythonhosted.org/packages/a7/9f/4331142bc98c10ca6436d2d620c3e165f31e6c58d43479985afce6f3191c/coverage-7.10.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7ea7c6c9d0d286d04ed3541747e6597cbe4971f22648b68248f7ddcd329207f0", size = 251604, upload-time = "2025-09-21T20:02:06.034Z" }, + { url = "https://files.pythonhosted.org/packages/ce/60/bda83b96602036b77ecf34e6393a3836365481b69f7ed7079ab85048202b/coverage-7.10.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b9be91986841a75042b3e3243d0b3cb0b2434252b977baaf0cd56e960fe1e46f", size = 249497, upload-time = "2025-09-21T20:02:07.619Z" }, + { url = "https://files.pythonhosted.org/packages/5f/af/152633ff35b2af63977edd835d8e6430f0caef27d171edf2fc76c270ef31/coverage-7.10.7-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:b281d5eca50189325cfe1f365fafade89b14b4a78d9b40b05ddd1fc7d2a10a9c", size = 249350, upload-time = "2025-09-21T20:02:10.34Z" }, + { url = "https://files.pythonhosted.org/packages/9d/71/d92105d122bd21cebba877228990e1646d862e34a98bb3374d3fece5a794/coverage-7.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:99e4aa63097ab1118e75a848a28e40d68b08a5e19ce587891ab7fd04475e780f", size = 251111, upload-time = "2025-09-21T20:02:12.122Z" }, + { url = "https://files.pythonhosted.org/packages/a2/9e/9fdb08f4bf476c912f0c3ca292e019aab6712c93c9344a1653986c3fd305/coverage-7.10.7-cp313-cp313-win32.whl", hash = "sha256:dc7c389dce432500273eaf48f410b37886be9208b2dd5710aaf7c57fd442c698", size = 220746, upload-time = "2025-09-21T20:02:13.919Z" }, + { url = "https://files.pythonhosted.org/packages/b1/b1/a75fd25df44eab52d1931e89980d1ada46824c7a3210be0d3c88a44aaa99/coverage-7.10.7-cp313-cp313-win_amd64.whl", hash = "sha256:cac0fdca17b036af3881a9d2729a850b76553f3f716ccb0360ad4dbc06b3b843", size = 221541, upload-time = "2025-09-21T20:02:15.57Z" }, + { url = "https://files.pythonhosted.org/packages/14/3a/d720d7c989562a6e9a14b2c9f5f2876bdb38e9367126d118495b89c99c37/coverage-7.10.7-cp313-cp313-win_arm64.whl", hash = "sha256:4b6f236edf6e2f9ae8fcd1332da4e791c1b6ba0dc16a2dc94590ceccb482e546", size = 220170, upload-time = "2025-09-21T20:02:17.395Z" }, + { url = "https://files.pythonhosted.org/packages/bb/22/e04514bf2a735d8b0add31d2b4ab636fc02370730787c576bb995390d2d5/coverage-7.10.7-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a0ec07fd264d0745ee396b666d47cef20875f4ff2375d7c4f58235886cc1ef0c", size = 219029, upload-time = "2025-09-21T20:02:18.936Z" }, + { url = "https://files.pythonhosted.org/packages/11/0b/91128e099035ece15da3445d9015e4b4153a6059403452d324cbb0a575fa/coverage-7.10.7-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd5e856ebb7bfb7672b0086846db5afb4567a7b9714b8a0ebafd211ec7ce6a15", size = 219259, upload-time = "2025-09-21T20:02:20.44Z" }, + { url = "https://files.pythonhosted.org/packages/8b/51/66420081e72801536a091a0c8f8c1f88a5c4bf7b9b1bdc6222c7afe6dc9b/coverage-7.10.7-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f57b2a3c8353d3e04acf75b3fed57ba41f5c0646bbf1d10c7c282291c97936b4", size = 260592, upload-time = "2025-09-21T20:02:22.313Z" }, + { url = "https://files.pythonhosted.org/packages/5d/22/9b8d458c2881b22df3db5bb3e7369e63d527d986decb6c11a591ba2364f7/coverage-7.10.7-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ef2319dd15a0b009667301a3f84452a4dc6fddfd06b0c5c53ea472d3989fbf0", size = 262768, upload-time = "2025-09-21T20:02:24.287Z" }, + { url = "https://files.pythonhosted.org/packages/f7/08/16bee2c433e60913c610ea200b276e8eeef084b0d200bdcff69920bd5828/coverage-7.10.7-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:83082a57783239717ceb0ad584de3c69cf581b2a95ed6bf81ea66034f00401c0", size = 264995, upload-time = "2025-09-21T20:02:26.133Z" }, + { url = "https://files.pythonhosted.org/packages/20/9d/e53eb9771d154859b084b90201e5221bca7674ba449a17c101a5031d4054/coverage-7.10.7-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:50aa94fb1fb9a397eaa19c0d5ec15a5edd03a47bf1a3a6111a16b36e190cff65", size = 259546, upload-time = "2025-09-21T20:02:27.716Z" }, + { url = "https://files.pythonhosted.org/packages/ad/b0/69bc7050f8d4e56a89fb550a1577d5d0d1db2278106f6f626464067b3817/coverage-7.10.7-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2120043f147bebb41c85b97ac45dd173595ff14f2a584f2963891cbcc3091541", size = 262544, upload-time = "2025-09-21T20:02:29.216Z" }, + { url = "https://files.pythonhosted.org/packages/ef/4b/2514b060dbd1bc0aaf23b852c14bb5818f244c664cb16517feff6bb3a5ab/coverage-7.10.7-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:2fafd773231dd0378fdba66d339f84904a8e57a262f583530f4f156ab83863e6", size = 260308, upload-time = "2025-09-21T20:02:31.226Z" }, + { url = "https://files.pythonhosted.org/packages/54/78/7ba2175007c246d75e496f64c06e94122bdb914790a1285d627a918bd271/coverage-7.10.7-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:0b944ee8459f515f28b851728ad224fa2d068f1513ef6b7ff1efafeb2185f999", size = 258920, upload-time = "2025-09-21T20:02:32.823Z" }, + { url = "https://files.pythonhosted.org/packages/c0/b3/fac9f7abbc841409b9a410309d73bfa6cfb2e51c3fada738cb607ce174f8/coverage-7.10.7-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4b583b97ab2e3efe1b3e75248a9b333bd3f8b0b1b8e5b45578e05e5850dfb2c2", size = 261434, upload-time = "2025-09-21T20:02:34.86Z" }, + { url = "https://files.pythonhosted.org/packages/ee/51/a03bec00d37faaa891b3ff7387192cef20f01604e5283a5fabc95346befa/coverage-7.10.7-cp313-cp313t-win32.whl", hash = "sha256:2a78cd46550081a7909b3329e2266204d584866e8d97b898cd7fb5ac8d888b1a", size = 221403, upload-time = "2025-09-21T20:02:37.034Z" }, + { url = "https://files.pythonhosted.org/packages/53/22/3cf25d614e64bf6d8e59c7c669b20d6d940bb337bdee5900b9ca41c820bb/coverage-7.10.7-cp313-cp313t-win_amd64.whl", hash = "sha256:33a5e6396ab684cb43dc7befa386258acb2d7fae7f67330ebb85ba4ea27938eb", size = 222469, upload-time = "2025-09-21T20:02:39.011Z" }, + { url = "https://files.pythonhosted.org/packages/49/a1/00164f6d30d8a01c3c9c48418a7a5be394de5349b421b9ee019f380df2a0/coverage-7.10.7-cp313-cp313t-win_arm64.whl", hash = "sha256:86b0e7308289ddde73d863b7683f596d8d21c7d8664ce1dee061d0bcf3fbb4bb", size = 220731, upload-time = "2025-09-21T20:02:40.939Z" }, + { url = "https://files.pythonhosted.org/packages/23/9c/5844ab4ca6a4dd97a1850e030a15ec7d292b5c5cb93082979225126e35dd/coverage-7.10.7-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b06f260b16ead11643a5a9f955bd4b5fd76c1a4c6796aeade8520095b75de520", size = 218302, upload-time = "2025-09-21T20:02:42.527Z" }, + { url = "https://files.pythonhosted.org/packages/f0/89/673f6514b0961d1f0e20ddc242e9342f6da21eaba3489901b565c0689f34/coverage-7.10.7-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:212f8f2e0612778f09c55dd4872cb1f64a1f2b074393d139278ce902064d5b32", size = 218578, upload-time = "2025-09-21T20:02:44.468Z" }, + { url = "https://files.pythonhosted.org/packages/05/e8/261cae479e85232828fb17ad536765c88dd818c8470aca690b0ac6feeaa3/coverage-7.10.7-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3445258bcded7d4aa630ab8296dea4d3f15a255588dd535f980c193ab6b95f3f", size = 249629, upload-time = "2025-09-21T20:02:46.503Z" }, + { url = "https://files.pythonhosted.org/packages/82/62/14ed6546d0207e6eda876434e3e8475a3e9adbe32110ce896c9e0c06bb9a/coverage-7.10.7-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bb45474711ba385c46a0bfe696c695a929ae69ac636cda8f532be9e8c93d720a", size = 252162, upload-time = "2025-09-21T20:02:48.689Z" }, + { url = "https://files.pythonhosted.org/packages/ff/49/07f00db9ac6478e4358165a08fb41b469a1b053212e8a00cb02f0d27a05f/coverage-7.10.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:813922f35bd800dca9994c5971883cbc0d291128a5de6b167c7aa697fcf59360", size = 253517, upload-time = "2025-09-21T20:02:50.31Z" }, + { url = "https://files.pythonhosted.org/packages/a2/59/c5201c62dbf165dfbc91460f6dbbaa85a8b82cfa6131ac45d6c1bfb52deb/coverage-7.10.7-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:93c1b03552081b2a4423091d6fb3787265b8f86af404cff98d1b5342713bdd69", size = 249632, upload-time = "2025-09-21T20:02:51.971Z" }, + { url = "https://files.pythonhosted.org/packages/07/ae/5920097195291a51fb00b3a70b9bbd2edbfe3c84876a1762bd1ef1565ebc/coverage-7.10.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:cc87dd1b6eaf0b848eebb1c86469b9f72a1891cb42ac7adcfbce75eadb13dd14", size = 251520, upload-time = "2025-09-21T20:02:53.858Z" }, + { url = "https://files.pythonhosted.org/packages/b9/3c/a815dde77a2981f5743a60b63df31cb322c944843e57dbd579326625a413/coverage-7.10.7-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:39508ffda4f343c35f3236fe8d1a6634a51f4581226a1262769d7f970e73bffe", size = 249455, upload-time = "2025-09-21T20:02:55.807Z" }, + { url = "https://files.pythonhosted.org/packages/aa/99/f5cdd8421ea656abefb6c0ce92556709db2265c41e8f9fc6c8ae0f7824c9/coverage-7.10.7-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:925a1edf3d810537c5a3abe78ec5530160c5f9a26b1f4270b40e62cc79304a1e", size = 249287, upload-time = "2025-09-21T20:02:57.784Z" }, + { url = "https://files.pythonhosted.org/packages/c3/7a/e9a2da6a1fc5d007dd51fca083a663ab930a8c4d149c087732a5dbaa0029/coverage-7.10.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2c8b9a0636f94c43cd3576811e05b89aa9bc2d0a85137affc544ae5cb0e4bfbd", size = 250946, upload-time = "2025-09-21T20:02:59.431Z" }, + { url = "https://files.pythonhosted.org/packages/ef/5b/0b5799aa30380a949005a353715095d6d1da81927d6dbed5def2200a4e25/coverage-7.10.7-cp314-cp314-win32.whl", hash = "sha256:b7b8288eb7cdd268b0304632da8cb0bb93fadcfec2fe5712f7b9cc8f4d487be2", size = 221009, upload-time = "2025-09-21T20:03:01.324Z" }, + { url = "https://files.pythonhosted.org/packages/da/b0/e802fbb6eb746de006490abc9bb554b708918b6774b722bb3a0e6aa1b7de/coverage-7.10.7-cp314-cp314-win_amd64.whl", hash = "sha256:1ca6db7c8807fb9e755d0379ccc39017ce0a84dcd26d14b5a03b78563776f681", size = 221804, upload-time = "2025-09-21T20:03:03.4Z" }, + { url = "https://files.pythonhosted.org/packages/9e/e8/71d0c8e374e31f39e3389bb0bd19e527d46f00ea8571ec7ec8fd261d8b44/coverage-7.10.7-cp314-cp314-win_arm64.whl", hash = "sha256:097c1591f5af4496226d5783d036bf6fd6cd0cbc132e071b33861de756efb880", size = 220384, upload-time = "2025-09-21T20:03:05.111Z" }, + { url = "https://files.pythonhosted.org/packages/62/09/9a5608d319fa3eba7a2019addeacb8c746fb50872b57a724c9f79f146969/coverage-7.10.7-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:a62c6ef0d50e6de320c270ff91d9dd0a05e7250cac2a800b7784bae474506e63", size = 219047, upload-time = "2025-09-21T20:03:06.795Z" }, + { url = "https://files.pythonhosted.org/packages/f5/6f/f58d46f33db9f2e3647b2d0764704548c184e6f5e014bef528b7f979ef84/coverage-7.10.7-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:9fa6e4dd51fe15d8738708a973470f67a855ca50002294852e9571cdbd9433f2", size = 219266, upload-time = "2025-09-21T20:03:08.495Z" }, + { url = "https://files.pythonhosted.org/packages/74/5c/183ffc817ba68e0b443b8c934c8795553eb0c14573813415bd59941ee165/coverage-7.10.7-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:8fb190658865565c549b6b4706856d6a7b09302c797eb2cf8e7fe9dabb043f0d", size = 260767, upload-time = "2025-09-21T20:03:10.172Z" }, + { url = "https://files.pythonhosted.org/packages/0f/48/71a8abe9c1ad7e97548835e3cc1adbf361e743e9d60310c5f75c9e7bf847/coverage-7.10.7-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:affef7c76a9ef259187ef31599a9260330e0335a3011732c4b9effa01e1cd6e0", size = 262931, upload-time = "2025-09-21T20:03:11.861Z" }, + { url = "https://files.pythonhosted.org/packages/84/fd/193a8fb132acfc0a901f72020e54be5e48021e1575bb327d8ee1097a28fd/coverage-7.10.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e16e07d85ca0cf8bafe5f5d23a0b850064e8e945d5677492b06bbe6f09cc699", size = 265186, upload-time = "2025-09-21T20:03:13.539Z" }, + { url = "https://files.pythonhosted.org/packages/b1/8f/74ecc30607dd95ad50e3034221113ccb1c6d4e8085cc761134782995daae/coverage-7.10.7-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:03ffc58aacdf65d2a82bbeb1ffe4d01ead4017a21bfd0454983b88ca73af94b9", size = 259470, upload-time = "2025-09-21T20:03:15.584Z" }, + { url = "https://files.pythonhosted.org/packages/0f/55/79ff53a769f20d71b07023ea115c9167c0bb56f281320520cf64c5298a96/coverage-7.10.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1b4fd784344d4e52647fd7857b2af5b3fbe6c239b0b5fa63e94eb67320770e0f", size = 262626, upload-time = "2025-09-21T20:03:17.673Z" }, + { url = "https://files.pythonhosted.org/packages/88/e2/dac66c140009b61ac3fc13af673a574b00c16efdf04f9b5c740703e953c0/coverage-7.10.7-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:0ebbaddb2c19b71912c6f2518e791aa8b9f054985a0769bdb3a53ebbc765c6a1", size = 260386, upload-time = "2025-09-21T20:03:19.36Z" }, + { url = "https://files.pythonhosted.org/packages/a2/f1/f48f645e3f33bb9ca8a496bc4a9671b52f2f353146233ebd7c1df6160440/coverage-7.10.7-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:a2d9a3b260cc1d1dbdb1c582e63ddcf5363426a1a68faa0f5da28d8ee3c722a0", size = 258852, upload-time = "2025-09-21T20:03:21.007Z" }, + { url = "https://files.pythonhosted.org/packages/bb/3b/8442618972c51a7affeead957995cfa8323c0c9bcf8fa5a027421f720ff4/coverage-7.10.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a3cc8638b2480865eaa3926d192e64ce6c51e3d29c849e09d5b4ad95efae5399", size = 261534, upload-time = "2025-09-21T20:03:23.12Z" }, + { url = "https://files.pythonhosted.org/packages/b2/dc/101f3fa3a45146db0cb03f5b4376e24c0aac818309da23e2de0c75295a91/coverage-7.10.7-cp314-cp314t-win32.whl", hash = "sha256:67f8c5cbcd3deb7a60b3345dffc89a961a484ed0af1f6f73de91705cc6e31235", size = 221784, upload-time = "2025-09-21T20:03:24.769Z" }, + { url = "https://files.pythonhosted.org/packages/4c/a1/74c51803fc70a8a40d7346660379e144be772bab4ac7bb6e6b905152345c/coverage-7.10.7-cp314-cp314t-win_amd64.whl", hash = "sha256:e1ed71194ef6dea7ed2d5cb5f7243d4bcd334bfb63e59878519be558078f848d", size = 222905, upload-time = "2025-09-21T20:03:26.93Z" }, + { url = "https://files.pythonhosted.org/packages/12/65/f116a6d2127df30bcafbceef0302d8a64ba87488bf6f73a6d8eebf060873/coverage-7.10.7-cp314-cp314t-win_arm64.whl", hash = "sha256:7fe650342addd8524ca63d77b2362b02345e5f1a093266787d210c70a50b471a", size = 220922, upload-time = "2025-09-21T20:03:28.672Z" }, + { url = "https://files.pythonhosted.org/packages/ec/16/114df1c291c22cac3b0c127a73e0af5c12ed7bbb6558d310429a0ae24023/coverage-7.10.7-py3-none-any.whl", hash = "sha256:f7941f6f2fe6dd6807a1208737b8a0cbcf1cc6d7b07d24998ad2d63590868260", size = 209952, upload-time = "2025-09-21T20:03:53.918Z" }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + +[[package]] +name = "cryptography" +version = "45.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/1e/49527ac611af559665f71cbb8f92b332b5ec9c6fbc4e88b0f8e92f5e85df/cryptography-45.0.5.tar.gz", hash = "sha256:72e76caa004ab63accdf26023fccd1d087f6d90ec6048ff33ad0445abf7f605a", size = 744903, upload-time = "2025-07-02T13:06:25.941Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/fb/09e28bc0c46d2c547085e60897fea96310574c70fb21cd58a730a45f3403/cryptography-45.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:101ee65078f6dd3e5a028d4f19c07ffa4dd22cce6a20eaa160f8b5219911e7d8", size = 7043092, upload-time = "2025-07-02T13:05:01.514Z" }, + { url = "https://files.pythonhosted.org/packages/b1/05/2194432935e29b91fb649f6149c1a4f9e6d3d9fc880919f4ad1bcc22641e/cryptography-45.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3a264aae5f7fbb089dbc01e0242d3b67dffe3e6292e1f5182122bdf58e65215d", size = 4205926, upload-time = "2025-07-02T13:05:04.741Z" }, + { url = "https://files.pythonhosted.org/packages/07/8b/9ef5da82350175e32de245646b1884fc01124f53eb31164c77f95a08d682/cryptography-45.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e74d30ec9c7cb2f404af331d5b4099a9b322a8a6b25c4632755c8757345baac5", size = 4429235, upload-time = "2025-07-02T13:05:07.084Z" }, + { url = "https://files.pythonhosted.org/packages/7c/e1/c809f398adde1994ee53438912192d92a1d0fc0f2d7582659d9ef4c28b0c/cryptography-45.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:3af26738f2db354aafe492fb3869e955b12b2ef2e16908c8b9cb928128d42c57", size = 4209785, upload-time = "2025-07-02T13:05:09.321Z" }, + { url = "https://files.pythonhosted.org/packages/d0/8b/07eb6bd5acff58406c5e806eff34a124936f41a4fb52909ffa4d00815f8c/cryptography-45.0.5-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e6c00130ed423201c5bc5544c23359141660b07999ad82e34e7bb8f882bb78e0", size = 3893050, upload-time = "2025-07-02T13:05:11.069Z" }, + { url = "https://files.pythonhosted.org/packages/ec/ef/3333295ed58d900a13c92806b67e62f27876845a9a908c939f040887cca9/cryptography-45.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:dd420e577921c8c2d31289536c386aaa30140b473835e97f83bc71ea9d2baf2d", size = 4457379, upload-time = "2025-07-02T13:05:13.32Z" }, + { url = "https://files.pythonhosted.org/packages/d9/9d/44080674dee514dbb82b21d6fa5d1055368f208304e2ab1828d85c9de8f4/cryptography-45.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:d05a38884db2ba215218745f0781775806bde4f32e07b135348355fe8e4991d9", size = 4209355, upload-time = "2025-07-02T13:05:15.017Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d8/0749f7d39f53f8258e5c18a93131919ac465ee1f9dccaf1b3f420235e0b5/cryptography-45.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:ad0caded895a00261a5b4aa9af828baede54638754b51955a0ac75576b831b27", size = 4456087, upload-time = "2025-07-02T13:05:16.945Z" }, + { url = "https://files.pythonhosted.org/packages/09/d7/92acac187387bf08902b0bf0699816f08553927bdd6ba3654da0010289b4/cryptography-45.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9024beb59aca9d31d36fcdc1604dd9bbeed0a55bface9f1908df19178e2f116e", size = 4332873, upload-time = "2025-07-02T13:05:18.743Z" }, + { url = "https://files.pythonhosted.org/packages/03/c2/840e0710da5106a7c3d4153c7215b2736151bba60bf4491bdb421df5056d/cryptography-45.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:91098f02ca81579c85f66df8a588c78f331ca19089763d733e34ad359f474174", size = 4564651, upload-time = "2025-07-02T13:05:21.382Z" }, + { url = "https://files.pythonhosted.org/packages/2e/92/cc723dd6d71e9747a887b94eb3827825c6c24b9e6ce2bb33b847d31d5eaa/cryptography-45.0.5-cp311-abi3-win32.whl", hash = "sha256:926c3ea71a6043921050eaa639137e13dbe7b4ab25800932a8498364fc1abec9", size = 2929050, upload-time = "2025-07-02T13:05:23.39Z" }, + { url = "https://files.pythonhosted.org/packages/1f/10/197da38a5911a48dd5389c043de4aec4b3c94cb836299b01253940788d78/cryptography-45.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:b85980d1e345fe769cfc57c57db2b59cff5464ee0c045d52c0df087e926fbe63", size = 3403224, upload-time = "2025-07-02T13:05:25.202Z" }, + { url = "https://files.pythonhosted.org/packages/fe/2b/160ce8c2765e7a481ce57d55eba1546148583e7b6f85514472b1d151711d/cryptography-45.0.5-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f3562c2f23c612f2e4a6964a61d942f891d29ee320edb62ff48ffb99f3de9ae8", size = 7017143, upload-time = "2025-07-02T13:05:27.229Z" }, + { url = "https://files.pythonhosted.org/packages/c2/e7/2187be2f871c0221a81f55ee3105d3cf3e273c0a0853651d7011eada0d7e/cryptography-45.0.5-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3fcfbefc4a7f332dece7272a88e410f611e79458fab97b5efe14e54fe476f4fd", size = 4197780, upload-time = "2025-07-02T13:05:29.299Z" }, + { url = "https://files.pythonhosted.org/packages/b9/cf/84210c447c06104e6be9122661159ad4ce7a8190011669afceeaea150524/cryptography-45.0.5-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:460f8c39ba66af7db0545a8c6f2eabcbc5a5528fc1cf6c3fa9a1e44cec33385e", size = 4420091, upload-time = "2025-07-02T13:05:31.221Z" }, + { url = "https://files.pythonhosted.org/packages/3e/6a/cb8b5c8bb82fafffa23aeff8d3a39822593cee6e2f16c5ca5c2ecca344f7/cryptography-45.0.5-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:9b4cf6318915dccfe218e69bbec417fdd7c7185aa7aab139a2c0beb7468c89f0", size = 4198711, upload-time = "2025-07-02T13:05:33.062Z" }, + { url = "https://files.pythonhosted.org/packages/04/f7/36d2d69df69c94cbb2473871926daf0f01ad8e00fe3986ac3c1e8c4ca4b3/cryptography-45.0.5-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2089cc8f70a6e454601525e5bf2779e665d7865af002a5dec8d14e561002e135", size = 3883299, upload-time = "2025-07-02T13:05:34.94Z" }, + { url = "https://files.pythonhosted.org/packages/82/c7/f0ea40f016de72f81288e9fe8d1f6748036cb5ba6118774317a3ffc6022d/cryptography-45.0.5-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0027d566d65a38497bc37e0dd7c2f8ceda73597d2ac9ba93810204f56f52ebc7", size = 4450558, upload-time = "2025-07-02T13:05:37.288Z" }, + { url = "https://files.pythonhosted.org/packages/06/ae/94b504dc1a3cdf642d710407c62e86296f7da9e66f27ab12a1ee6fdf005b/cryptography-45.0.5-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:be97d3a19c16a9be00edf79dca949c8fa7eff621763666a145f9f9535a5d7f42", size = 4198020, upload-time = "2025-07-02T13:05:39.102Z" }, + { url = "https://files.pythonhosted.org/packages/05/2b/aaf0adb845d5dabb43480f18f7ca72e94f92c280aa983ddbd0bcd6ecd037/cryptography-45.0.5-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:7760c1c2e1a7084153a0f68fab76e754083b126a47d0117c9ed15e69e2103492", size = 4449759, upload-time = "2025-07-02T13:05:41.398Z" }, + { url = "https://files.pythonhosted.org/packages/91/e4/f17e02066de63e0100a3a01b56f8f1016973a1d67551beaf585157a86b3f/cryptography-45.0.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:6ff8728d8d890b3dda5765276d1bc6fb099252915a2cd3aff960c4c195745dd0", size = 4319991, upload-time = "2025-07-02T13:05:43.64Z" }, + { url = "https://files.pythonhosted.org/packages/f2/2e/e2dbd629481b499b14516eed933f3276eb3239f7cee2dcfa4ee6b44d4711/cryptography-45.0.5-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7259038202a47fdecee7e62e0fd0b0738b6daa335354396c6ddebdbe1206af2a", size = 4554189, upload-time = "2025-07-02T13:05:46.045Z" }, + { url = "https://files.pythonhosted.org/packages/f8/ea/a78a0c38f4c8736287b71c2ea3799d173d5ce778c7d6e3c163a95a05ad2a/cryptography-45.0.5-cp37-abi3-win32.whl", hash = "sha256:1e1da5accc0c750056c556a93c3e9cb828970206c68867712ca5805e46dc806f", size = 2911769, upload-time = "2025-07-02T13:05:48.329Z" }, + { url = "https://files.pythonhosted.org/packages/79/b3/28ac139109d9005ad3f6b6f8976ffede6706a6478e21c889ce36c840918e/cryptography-45.0.5-cp37-abi3-win_amd64.whl", hash = "sha256:90cb0a7bb35959f37e23303b7eed0a32280510030daba3f7fdfbb65defde6a97", size = 3390016, upload-time = "2025-07-02T13:05:50.811Z" }, + { url = "https://files.pythonhosted.org/packages/f8/8b/34394337abe4566848a2bd49b26bcd4b07fd466afd3e8cce4cb79a390869/cryptography-45.0.5-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:206210d03c1193f4e1ff681d22885181d47efa1ab3018766a7b32a7b3d6e6afd", size = 3575762, upload-time = "2025-07-02T13:05:53.166Z" }, + { url = "https://files.pythonhosted.org/packages/8b/5d/a19441c1e89afb0f173ac13178606ca6fab0d3bd3ebc29e9ed1318b507fc/cryptography-45.0.5-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c648025b6840fe62e57107e0a25f604db740e728bd67da4f6f060f03017d5097", size = 4140906, upload-time = "2025-07-02T13:05:55.914Z" }, + { url = "https://files.pythonhosted.org/packages/4b/db/daceb259982a3c2da4e619f45b5bfdec0e922a23de213b2636e78ef0919b/cryptography-45.0.5-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b8fa8b0a35a9982a3c60ec79905ba5bb090fc0b9addcfd3dc2dd04267e45f25e", size = 4374411, upload-time = "2025-07-02T13:05:57.814Z" }, + { url = "https://files.pythonhosted.org/packages/6a/35/5d06ad06402fc522c8bf7eab73422d05e789b4e38fe3206a85e3d6966c11/cryptography-45.0.5-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:14d96584701a887763384f3c47f0ca7c1cce322aa1c31172680eb596b890ec30", size = 4140942, upload-time = "2025-07-02T13:06:00.137Z" }, + { url = "https://files.pythonhosted.org/packages/65/79/020a5413347e44c382ef1f7f7e7a66817cd6273e3e6b5a72d18177b08b2f/cryptography-45.0.5-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:57c816dfbd1659a367831baca4b775b2a5b43c003daf52e9d57e1d30bc2e1b0e", size = 4374079, upload-time = "2025-07-02T13:06:02.043Z" }, + { url = "https://files.pythonhosted.org/packages/9b/c5/c0e07d84a9a2a8a0ed4f865e58f37c71af3eab7d5e094ff1b21f3f3af3bc/cryptography-45.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:b9e38e0a83cd51e07f5a48ff9691cae95a79bea28fe4ded168a8e5c6c77e819d", size = 3321362, upload-time = "2025-07-02T13:06:04.463Z" }, + { url = "https://files.pythonhosted.org/packages/c0/71/9bdbcfd58d6ff5084687fe722c58ac718ebedbc98b9f8f93781354e6d286/cryptography-45.0.5-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8c4a6ff8a30e9e3d38ac0539e9a9e02540ab3f827a3394f8852432f6b0ea152e", size = 3587878, upload-time = "2025-07-02T13:06:06.339Z" }, + { url = "https://files.pythonhosted.org/packages/f0/63/83516cfb87f4a8756eaa4203f93b283fda23d210fc14e1e594bd5f20edb6/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bd4c45986472694e5121084c6ebbd112aa919a25e783b87eb95953c9573906d6", size = 4152447, upload-time = "2025-07-02T13:06:08.345Z" }, + { url = "https://files.pythonhosted.org/packages/22/11/d2823d2a5a0bd5802b3565437add16f5c8ce1f0778bf3822f89ad2740a38/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:982518cd64c54fcada9d7e5cf28eabd3ee76bd03ab18e08a48cad7e8b6f31b18", size = 4386778, upload-time = "2025-07-02T13:06:10.263Z" }, + { url = "https://files.pythonhosted.org/packages/5f/38/6bf177ca6bce4fe14704ab3e93627c5b0ca05242261a2e43ef3168472540/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:12e55281d993a793b0e883066f590c1ae1e802e3acb67f8b442e721e475e6463", size = 4151627, upload-time = "2025-07-02T13:06:13.097Z" }, + { url = "https://files.pythonhosted.org/packages/38/6a/69fc67e5266bff68a91bcb81dff8fb0aba4d79a78521a08812048913e16f/cryptography-45.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:5aa1e32983d4443e310f726ee4b071ab7569f58eedfdd65e9675484a4eb67bd1", size = 4385593, upload-time = "2025-07-02T13:06:15.689Z" }, + { url = "https://files.pythonhosted.org/packages/f6/34/31a1604c9a9ade0fdab61eb48570e09a796f4d9836121266447b0eaf7feb/cryptography-45.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:e357286c1b76403dd384d938f93c46b2b058ed4dfcdce64a770f0537ed3feb6f", size = 3331106, upload-time = "2025-07-02T13:06:18.058Z" }, +] + [[package]] name = "cssselect2" version = "0.8.0" @@ -611,10 +773,13 @@ dependencies = [ { name = "jsonschema" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "python-multipart" }, { name = "pywin32", marker = "sys_platform == 'win32'" }, { name = "sse-starlette" }, { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, ] @@ -632,6 +797,7 @@ ws = [ [package.dev-dependencies] dev = [ + { name = "coverage", extra = ["toml"] }, { name = "dirty-equals" }, { name = "inline-snapshot" }, { name = "pyright" }, @@ -658,6 +824,7 @@ requires-dist = [ { name = "jsonschema", specifier = ">=4.20.0" }, { name = "pydantic", specifier = ">=2.11.0,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.1" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, { name = "python-multipart", specifier = ">=0.0.9" }, { name = "pywin32", marker = "sys_platform == 'win32'", specifier = ">=310" }, @@ -665,6 +832,8 @@ requires-dist = [ { name = "sse-starlette", specifier = ">=1.6.1" }, { name = "starlette", specifier = ">=0.27" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.16.0" }, + { name = "typing-extensions", specifier = ">=4.9.0" }, + { name = "typing-inspection", specifier = ">=0.4.1" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.31.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] @@ -672,6 +841,7 @@ provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ + { name = "coverage", extras = ["toml"], specifier = "==7.10.7" }, { name = "dirty-equals", specifier = ">=0.9.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "pyright", specifier = ">=1.1.400" }, @@ -690,6 +860,72 @@ docs = [ { name = "mkdocstrings-python", specifier = ">=1.12.2" }, ] +[[package]] +name = "mcp-conformance-auth-client" +version = "0.1.0" +source = { editable = "examples/clients/conformance-auth-client" } +dependencies = [ + { name = "httpx" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "httpx", specifier = ">=0.28.1" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.379" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-everything-server" +version = "0.1.0" +source = { editable = "examples/servers/everything-server" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.2.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-auth" version = "0.1.0" @@ -731,6 +967,68 @@ dev = [ { name = "ruff", specifier = ">=0.8.5" }, ] +[[package]] +name = "mcp-simple-auth-client" +version = "0.1.0" +source = { editable = "examples/clients/simple-auth-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.2.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.379" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-chatbot" +version = "0.1.0" +source = { editable = "examples/clients/simple-chatbot" } +dependencies = [ + { name = "mcp" }, + { name = "python-dotenv" }, + { name = "requests" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "mcp", editable = "." }, + { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "requests", specifier = ">=2.31.0" }, + { name = "uvicorn", specifier = ">=0.32.1" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.379" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-pagination" version = "0.1.0" @@ -904,6 +1202,126 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-task" +version = "0.1.0" +source = { editable = "examples/servers/simple-task" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-task-client" +version = "0.1.0" +source = { editable = "examples/clients/simple-task-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-task-interactive" +version = "0.1.0" +source = { editable = "examples/servers/simple-task-interactive" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-task-interactive-client" +version = "0.1.0" +source = { editable = "examples/clients/simple-task-interactive-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0" @@ -948,6 +1366,83 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "mcp", editable = "." }] +[[package]] +name = "mcp-sse-polling-client" +version = "0.1.0" +source = { editable = "examples/clients/sse-polling-client" } +dependencies = [ + { name = "click" }, + { name = "mcp" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.2.0" }, + { name = "mcp", editable = "." }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-sse-polling-demo" +version = "0.1.0" +source = { editable = "examples/servers/sse-polling-demo" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.2.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-structured-output-lowlevel" +version = "0.1.0" +source = { virtual = "examples/servers/structured-output-lowlevel" } +dependencies = [ + { name = "mcp" }, +] + +[package.metadata] +requires-dist = [{ name = "mcp", editable = "." }] + [[package]] name = "mdurl" version = "0.1.2" @@ -1411,6 +1906,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pymdown-extensions" version = "10.16.1" @@ -1916,15 +2425,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.47.3" +version = "0.49.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/15/b9/cc3017f9a9c9b6e27c5106cc10cc7904653c3eec0729793aec10479dd669/starlette-0.47.3.tar.gz", hash = "sha256:6bc94f839cc176c4858894f1f8908f0ab79dfec1a6b8402f6da9be26ebea52e9", size = 2584144, upload-time = "2025-08-24T13:36:42.122Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/3f/507c21db33b66fb027a332f2cb3abbbe924cc3a79ced12f01ed8645955c9/starlette-0.49.1.tar.gz", hash = "sha256:481a43b71e24ed8c43b11ea02f5353d77840e01480881b8cb5a26b8cae64a8cb", size = 2654703, upload-time = "2025-10-28T17:34:10.928Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ce/fd/901cfa59aaa5b30a99e16876f11abe38b59a1a2c51ffb3d7142bb6089069/starlette-0.47.3-py3-none-any.whl", hash = "sha256:89c0778ca62a76b826101e7c709e70680a1699ca7da6b44d38eb0a7e61fe4b51", size = 72991, upload-time = "2025-08-24T13:36:40.887Z" }, + { url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" }, ] [[package]]