From abb2f0aa14f94f61be2e1f72711aa35284068f5b Mon Sep 17 00:00:00 2001 From: Marcelo Paternostro <64930576+m-paternostro@users.noreply.github.com> Date: Mon, 18 Aug 2025 14:11:15 +0100 Subject: [PATCH 1/4] =?UTF-8?q?feature(middleware):=20Composable=20fetch?= =?UTF-8?q?=20middleware=20for=20auth=20and=20cross=E2=80=91cutting=20conc?= =?UTF-8?q?erns=20=20(#485)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Soria Parra Co-authored-by: Claude --- src/client/middleware.test.ts | 1213 +++++++++++++++++++++++++++++++++ src/client/middleware.ts | 358 ++++++++++ 2 files changed, 1571 insertions(+) create mode 100644 src/client/middleware.test.ts create mode 100644 src/client/middleware.ts diff --git a/src/client/middleware.test.ts b/src/client/middleware.test.ts new file mode 100644 index 000000000..265aa70d6 --- /dev/null +++ b/src/client/middleware.test.ts @@ -0,0 +1,1213 @@ +import { + withOAuth, + withLogging, + applyMiddlewares, + createMiddleware, +} from "./middleware.js"; +import { OAuthClientProvider } from "./auth.js"; +import { FetchLike } from "../shared/transport.js"; + +jest.mock("../client/auth.js", () => { + const actual = jest.requireActual("../client/auth.js"); + return { + ...actual, + auth: jest.fn(), + extractResourceMetadataUrl: jest.fn(), + }; +}); + +import { auth, extractResourceMetadataUrl } from "./auth.js"; + +const mockAuth = auth as jest.MockedFunction; +const mockExtractResourceMetadataUrl = + extractResourceMetadataUrl as jest.MockedFunction< + typeof extractResourceMetadataUrl + >; + +describe("withOAuth", () => { + let mockProvider: jest.Mocked; + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + + mockProvider = { + get redirectUrl() { + return "http://localhost/callback"; + }, + get clientMetadata() { + return { redirect_uris: ["http://localhost/callback"] }; + }, + tokens: jest.fn(), + saveTokens: jest.fn(), + clientInformation: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + invalidateCredentials: jest.fn(), + }; + + mockFetch = jest.fn(); + }); + + it("should add Authorization header when tokens are available (with explicit baseUrl)", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + mockFetch.mockResolvedValue(new Response("success", { status: 200 })); + + const enhancedFetch = withOAuth( + mockProvider, + "https://api.example.com/", + )(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + expect(mockFetch).toHaveBeenCalledWith( + "https://api.example.com/data", + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Authorization")).toBe("Bearer test-token"); + }); + + it("should add Authorization header when tokens are available (without baseUrl)", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + mockFetch.mockResolvedValue(new Response("success", { status: 200 })); + + // Test without baseUrl - should extract from request URL + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + expect(mockFetch).toHaveBeenCalledWith( + "https://api.example.com/data", + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Authorization")).toBe("Bearer test-token"); + }); + + it("should handle requests without tokens (without baseUrl)", async () => { + mockProvider.tokens.mockResolvedValue(undefined); + mockFetch.mockResolvedValue(new Response("success", { status: 200 })); + + // Test without baseUrl + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + expect(mockFetch).toHaveBeenCalledTimes(1); + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Authorization")).toBeNull(); + }); + + it("should retry request after successful auth on 401 response (with explicit baseUrl)", async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: "old-token", + token_type: "Bearer", + expires_in: 3600, + }) + .mockResolvedValueOnce({ + access_token: "new-token", + token_type: "Bearer", + expires_in: 3600, + }); + + const unauthorizedResponse = new Response("Unauthorized", { + status: 401, + headers: { "www-authenticate": 'Bearer realm="oauth"' }, + }); + const successResponse = new Response("success", { status: 200 }); + + mockFetch + .mockResolvedValueOnce(unauthorizedResponse) + .mockResolvedValueOnce(successResponse); + + const mockResourceUrl = new URL( + "https://oauth.example.com/.well-known/oauth-protected-resource", + ); + mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl); + mockAuth.mockResolvedValue("AUTHORIZED"); + + const enhancedFetch = withOAuth( + mockProvider, + "https://api.example.com/", + )(mockFetch); + + const result = await enhancedFetch("https://api.example.com/data"); + + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: "https://api.example.com/", + resourceMetadataUrl: mockResourceUrl, + fetchFn: mockFetch, + }); + + // Verify the retry used the new token + const retryCallArgs = mockFetch.mock.calls[1]; + const retryHeaders = retryCallArgs[1]?.headers as Headers; + expect(retryHeaders.get("Authorization")).toBe("Bearer new-token"); + }); + + it("should retry request after successful auth on 401 response (without baseUrl)", async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: "old-token", + token_type: "Bearer", + expires_in: 3600, + }) + .mockResolvedValueOnce({ + access_token: "new-token", + token_type: "Bearer", + expires_in: 3600, + }); + + const unauthorizedResponse = new Response("Unauthorized", { + status: 401, + headers: { "www-authenticate": 'Bearer realm="oauth"' }, + }); + const successResponse = new Response("success", { status: 200 }); + + mockFetch + .mockResolvedValueOnce(unauthorizedResponse) + .mockResolvedValueOnce(successResponse); + + const mockResourceUrl = new URL( + "https://oauth.example.com/.well-known/oauth-protected-resource", + ); + mockExtractResourceMetadataUrl.mockReturnValue(mockResourceUrl); + mockAuth.mockResolvedValue("AUTHORIZED"); + + // Test without baseUrl - should extract from request URL + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + const result = await enhancedFetch("https://api.example.com/data"); + + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: "https://api.example.com/", // Should be extracted from request URL + resourceMetadataUrl: mockResourceUrl, + fetchFn: mockFetch, + }); + + // Verify the retry used the new token + const retryCallArgs = mockFetch.mock.calls[1]; + const retryHeaders = retryCallArgs[1]?.headers as Headers; + expect(retryHeaders.get("Authorization")).toBe("Bearer new-token"); + }); + + it("should throw UnauthorizedError when auth returns REDIRECT (without baseUrl)", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + mockFetch.mockResolvedValue(new Response("Unauthorized", { status: 401 })); + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockResolvedValue("REDIRECT"); + + // Test without baseUrl + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( + "Authentication requires user authorization - redirect initiated", + ); + }); + + it("should throw UnauthorizedError when auth fails", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + mockFetch.mockResolvedValue(new Response("Unauthorized", { status: 401 })); + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockRejectedValue(new Error("Network error")); + + const enhancedFetch = withOAuth( + mockProvider, + "https://api.example.com/", + )(mockFetch); + + await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( + "Failed to re-authenticate: Network error", + ); + }); + + it("should handle persistent 401 responses after auth", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + // Always return 401 + mockFetch.mockResolvedValue(new Response("Unauthorized", { status: 401 })); + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockResolvedValue("AUTHORIZED"); + + const enhancedFetch = withOAuth( + mockProvider, + "https://api.example.com/", + )(mockFetch); + + await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( + "Authentication failed for https://api.example.com/data", + ); + + // Should have made initial request + 1 retry after auth = 2 total + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledTimes(1); + }); + + it("should preserve original request method and body", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + mockFetch.mockResolvedValue(new Response("success", { status: 200 })); + + const enhancedFetch = withOAuth( + mockProvider, + "https://api.example.com/", + )(mockFetch); + + const requestBody = JSON.stringify({ data: "test" }); + await enhancedFetch("https://api.example.com/data", { + method: "POST", + body: requestBody, + headers: { "Content-Type": "application/json" }, + }); + + expect(mockFetch).toHaveBeenCalledWith( + "https://api.example.com/data", + expect.objectContaining({ + method: "POST", + body: requestBody, + headers: expect.any(Headers), + }), + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Content-Type")).toBe("application/json"); + expect(headers.get("Authorization")).toBe("Bearer test-token"); + }); + + it("should handle non-401 errors normally", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + const serverErrorResponse = new Response("Server Error", { status: 500 }); + mockFetch.mockResolvedValue(serverErrorResponse); + + const enhancedFetch = withOAuth( + mockProvider, + "https://api.example.com/", + )(mockFetch); + + const result = await enhancedFetch("https://api.example.com/data"); + + expect(result).toBe(serverErrorResponse); + expect(mockFetch).toHaveBeenCalledTimes(1); + expect(mockAuth).not.toHaveBeenCalled(); + }); + + it("should handle URL object as input (without baseUrl)", async () => { + mockProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + }); + + mockFetch.mockResolvedValue(new Response("success", { status: 200 })); + + // Test URL object without baseUrl - should extract origin from URL object + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + await enhancedFetch(new URL("https://api.example.com/data")); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + }); + + it("should handle URL object in auth retry (without baseUrl)", async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: "old-token", + token_type: "Bearer", + expires_in: 3600, + }) + .mockResolvedValueOnce({ + access_token: "new-token", + token_type: "Bearer", + expires_in: 3600, + }); + + const unauthorizedResponse = new Response("Unauthorized", { status: 401 }); + const successResponse = new Response("success", { status: 200 }); + + mockFetch + .mockResolvedValueOnce(unauthorizedResponse) + .mockResolvedValueOnce(successResponse); + + mockExtractResourceMetadataUrl.mockReturnValue(undefined); + mockAuth.mockResolvedValue("AUTHORIZED"); + + const enhancedFetch = withOAuth(mockProvider)(mockFetch); + + const result = await enhancedFetch(new URL("https://api.example.com/data")); + + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: "https://api.example.com/", // Should extract origin from URL object + resourceMetadataUrl: undefined, + fetchFn: mockFetch, + }); + }); +}); + +describe("withLogging", () => { + let mockFetch: jest.MockedFunction; + let mockLogger: jest.MockedFunction< + (input: { + method: string; + url: string | URL; + status: number; + statusText: string; + duration: number; + requestHeaders?: Headers; + responseHeaders?: Headers; + error?: Error; + }) => void + >; + let consoleErrorSpy: jest.SpyInstance; + let consoleLogSpy: jest.SpyInstance; + + beforeEach(() => { + jest.clearAllMocks(); + + consoleErrorSpy = jest.spyOn(console, "error").mockImplementation(() => {}); + consoleLogSpy = jest.spyOn(console, "log").mockImplementation(() => {}); + + mockFetch = jest.fn(); + mockLogger = jest.fn(); + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + consoleLogSpy.mockRestore(); + }); + + it("should log successful requests with default logger", async () => { + const response = new Response("success", { status: 200, statusText: "OK" }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging()(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + expect(consoleLogSpy).toHaveBeenCalledWith( + expect.stringMatching( + /HTTP GET https:\/\/api\.example\.com\/data 200 OK \(\d+\.\d+ms\)/, + ), + ); + }); + + it("should log error responses with default logger", async () => { + const response = new Response("Not Found", { + status: 404, + statusText: "Not Found", + }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging()(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringMatching( + /HTTP GET https:\/\/api\.example\.com\/data 404 Not Found \(\d+\.\d+ms\)/, + ), + ); + }); + + it("should log network errors with default logger", async () => { + const networkError = new Error("Network connection failed"); + mockFetch.mockRejectedValue(networkError); + + const enhancedFetch = withLogging()(mockFetch); + + await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( + "Network connection failed", + ); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringMatching( + /HTTP GET https:\/\/api\.example\.com\/data failed: Network connection failed \(\d+\.\d+ms\)/, + ), + ); + }); + + it("should use custom logger when provided", async () => { + const response = new Response("success", { status: 200, statusText: "OK" }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging({ logger: mockLogger })(mockFetch); + + await enhancedFetch("https://api.example.com/data", { method: "POST" }); + + expect(mockLogger).toHaveBeenCalledWith({ + method: "POST", + url: "https://api.example.com/data", + status: 200, + statusText: "OK", + duration: expect.any(Number), + requestHeaders: undefined, + responseHeaders: undefined, + }); + + expect(consoleLogSpy).not.toHaveBeenCalled(); + }); + + it("should include request headers when configured", async () => { + const response = new Response("success", { status: 200, statusText: "OK" }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging({ + logger: mockLogger, + includeRequestHeaders: true, + })(mockFetch); + + await enhancedFetch("https://api.example.com/data", { + headers: { + Authorization: "Bearer token", + "Content-Type": "application/json", + }, + }); + + expect(mockLogger).toHaveBeenCalledWith({ + method: "GET", + url: "https://api.example.com/data", + status: 200, + statusText: "OK", + duration: expect.any(Number), + requestHeaders: expect.any(Headers), + responseHeaders: undefined, + }); + + const logCall = mockLogger.mock.calls[0][0]; + expect(logCall.requestHeaders?.get("Authorization")).toBe("Bearer token"); + expect(logCall.requestHeaders?.get("Content-Type")).toBe( + "application/json", + ); + }); + + it("should include response headers when configured", async () => { + const response = new Response("success", { + status: 200, + statusText: "OK", + headers: { + "Content-Type": "application/json", + "Cache-Control": "no-cache", + }, + }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging({ + logger: mockLogger, + includeResponseHeaders: true, + })(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + const logCall = mockLogger.mock.calls[0][0]; + expect(logCall.responseHeaders?.get("Content-Type")).toBe( + "application/json", + ); + expect(logCall.responseHeaders?.get("Cache-Control")).toBe("no-cache"); + }); + + it("should respect statusLevel option", async () => { + const successResponse = new Response("success", { + status: 200, + statusText: "OK", + }); + const errorResponse = new Response("Server Error", { + status: 500, + statusText: "Internal Server Error", + }); + + mockFetch + .mockResolvedValueOnce(successResponse) + .mockResolvedValueOnce(errorResponse); + + const enhancedFetch = withLogging({ + logger: mockLogger, + statusLevel: 400, + })(mockFetch); + + // 200 response should not be logged (below statusLevel 400) + await enhancedFetch("https://api.example.com/success"); + expect(mockLogger).not.toHaveBeenCalled(); + + // 500 response should be logged (above statusLevel 400) + await enhancedFetch("https://api.example.com/error"); + expect(mockLogger).toHaveBeenCalledWith({ + method: "GET", + url: "https://api.example.com/error", + status: 500, + statusText: "Internal Server Error", + duration: expect.any(Number), + requestHeaders: undefined, + responseHeaders: undefined, + }); + }); + + it("should always log network errors regardless of statusLevel", async () => { + const networkError = new Error("Connection timeout"); + mockFetch.mockRejectedValue(networkError); + + const enhancedFetch = withLogging({ + logger: mockLogger, + statusLevel: 500, // Very high log level + })(mockFetch); + + await expect(enhancedFetch("https://api.example.com/data")).rejects.toThrow( + "Connection timeout", + ); + + expect(mockLogger).toHaveBeenCalledWith({ + method: "GET", + url: "https://api.example.com/data", + status: 0, + statusText: "Network Error", + duration: expect.any(Number), + requestHeaders: undefined, + error: networkError, + }); + }); + + it("should include headers in default logger message when configured", async () => { + const response = new Response("success", { + status: 200, + statusText: "OK", + headers: { "Content-Type": "application/json" }, + }); + mockFetch.mockResolvedValue(response); + + const enhancedFetch = withLogging({ + includeRequestHeaders: true, + includeResponseHeaders: true, + })(mockFetch); + + await enhancedFetch("https://api.example.com/data", { + headers: { Authorization: "Bearer token" }, + }); + + expect(consoleLogSpy).toHaveBeenCalledWith( + expect.stringContaining("Request Headers: {authorization: Bearer token}"), + ); + expect(consoleLogSpy).toHaveBeenCalledWith( + expect.stringContaining( + "Response Headers: {content-type: application/json}", + ), + ); + }); + + it("should measure request duration accurately", async () => { + // Mock a slow response + const response = new Response("success", { status: 200 }); + mockFetch.mockImplementation(async () => { + await new Promise((resolve) => setTimeout(resolve, 100)); + return response; + }); + + const enhancedFetch = withLogging({ logger: mockLogger })(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + const logCall = mockLogger.mock.calls[0][0]; + expect(logCall.duration).toBeGreaterThanOrEqual(90); // Allow some margin for timing + }); +}); + +describe("applyMiddleware", () => { + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + mockFetch = jest.fn(); + }); + + it("should compose no middleware correctly", () => { + const response = new Response("success", { status: 200 }); + mockFetch.mockResolvedValue(response); + + const composedFetch = applyMiddlewares()(mockFetch); + + expect(composedFetch).toBe(mockFetch); + }); + + it("should compose single middleware correctly", async () => { + const response = new Response("success", { status: 200 }); + mockFetch.mockResolvedValue(response); + + // Create a middleware that adds a header + const middleware1 = + (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("X-Middleware-1", "applied"); + return next(input, { ...init, headers }); + }; + + const composedFetch = applyMiddlewares(middleware1)(mockFetch); + + await composedFetch("https://api.example.com/data"); + + expect(mockFetch).toHaveBeenCalledWith( + "https://api.example.com/data", + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("X-Middleware-1")).toBe("applied"); + }); + + it("should compose multiple middleware in order", async () => { + const response = new Response("success", { status: 200 }); + mockFetch.mockResolvedValue(response); + + // Create middleware that add identifying headers + const middleware1 = + (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("X-Middleware-1", "applied"); + return next(input, { ...init, headers }); + }; + + const middleware2 = + (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("X-Middleware-2", "applied"); + return next(input, { ...init, headers }); + }; + + const middleware3 = + (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("X-Middleware-3", "applied"); + return next(input, { ...init, headers }); + }; + + const composedFetch = applyMiddlewares( + middleware1, + middleware2, + middleware3, + )(mockFetch); + + await composedFetch("https://api.example.com/data"); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("X-Middleware-1")).toBe("applied"); + expect(headers.get("X-Middleware-2")).toBe("applied"); + expect(headers.get("X-Middleware-3")).toBe("applied"); + }); + + it("should work with real fetch middleware functions", async () => { + const response = new Response("success", { status: 200, statusText: "OK" }); + mockFetch.mockResolvedValue(response); + + // Create middleware that add identifying headers + const oauthMiddleware = + (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("Authorization", "Bearer test-token"); + return next(input, { ...init, headers }); + }; + + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const composedFetch = applyMiddlewares( + oauthMiddleware, + withLogging({ logger: mockLogger, statusLevel: 0 }), + )(mockFetch); + + await composedFetch("https://api.example.com/data"); + + // Should have both Authorization header and logging + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Authorization")).toBe("Bearer test-token"); + expect(mockLogger).toHaveBeenCalledWith({ + method: "GET", + url: "https://api.example.com/data", + status: 200, + statusText: "OK", + duration: expect.any(Number), + requestHeaders: undefined, + responseHeaders: undefined, + }); + }); + + it("should preserve error propagation through middleware", async () => { + const errorMiddleware = + (next: FetchLike) => async (input: string | URL, init?: RequestInit) => { + try { + return await next(input, init); + } catch (error) { + // Add context to the error + throw new Error( + `Middleware error: ${error instanceof Error ? error.message : String(error)}`, + ); + } + }; + + const originalError = new Error("Network failure"); + mockFetch.mockRejectedValue(originalError); + + const composedFetch = applyMiddlewares(errorMiddleware)(mockFetch); + + await expect(composedFetch("https://api.example.com/data")).rejects.toThrow( + "Middleware error: Network failure", + ); + }); +}); + +describe("Integration Tests", () => { + let mockProvider: jest.Mocked; + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + + mockProvider = { + get redirectUrl() { + return "http://localhost/callback"; + }, + get clientMetadata() { + return { redirect_uris: ["http://localhost/callback"] }; + }, + tokens: jest.fn(), + saveTokens: jest.fn(), + clientInformation: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + invalidateCredentials: jest.fn(), + }; + + mockFetch = jest.fn(); + }); + + it("should work with SSE transport pattern", async () => { + // Simulate how SSE transport might use the middleware + mockProvider.tokens.mockResolvedValue({ + access_token: "sse-token", + token_type: "Bearer", + expires_in: 3600, + }); + + const response = new Response('{"jsonrpc":"2.0","id":1,"result":{}}', { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + mockFetch.mockResolvedValue(response); + + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const enhancedFetch = applyMiddlewares( + withOAuth( + mockProvider as OAuthClientProvider, + "https://mcp-server.example.com/", + ), + withLogging({ logger: mockLogger, statusLevel: 400 }), // Only log errors + )(mockFetch); + + // Simulate SSE POST request + await enhancedFetch("https://mcp-server.example.com/endpoint", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "tools/list", + id: 1, + }), + }); + + expect(mockFetch).toHaveBeenCalledWith( + "https://mcp-server.example.com/endpoint", + expect.objectContaining({ + method: "POST", + headers: expect.any(Headers), + body: expect.any(String), + }), + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Authorization")).toBe("Bearer sse-token"); + expect(headers.get("Content-Type")).toBe("application/json"); + }); + + it("should work with StreamableHTTP transport pattern", async () => { + // Simulate how StreamableHTTP transport might use the middleware + mockProvider.tokens.mockResolvedValue({ + access_token: "streamable-token", + token_type: "Bearer", + expires_in: 3600, + }); + + const response = new Response(null, { + status: 202, + headers: { "mcp-session-id": "session-123" }, + }); + mockFetch.mockResolvedValue(response); + + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const enhancedFetch = applyMiddlewares( + withOAuth( + mockProvider as OAuthClientProvider, + "https://streamable-server.example.com/", + ), + withLogging({ + logger: mockLogger, + includeResponseHeaders: true, + statusLevel: 0, + }), + )(mockFetch); + + // Simulate StreamableHTTP initialization request + await enhancedFetch("https://streamable-server.example.com/mcp", { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "initialize", + params: { protocolVersion: "2025-03-26", clientInfo: { name: "test" } }, + id: 1, + }), + }); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Authorization")).toBe("Bearer streamable-token"); + expect(headers.get("Accept")).toBe("application/json, text/event-stream"); + }); + + it("should handle auth retry in transport-like scenario", async () => { + mockProvider.tokens + .mockResolvedValueOnce({ + access_token: "expired-token", + token_type: "Bearer", + expires_in: 3600, + }) + .mockResolvedValueOnce({ + access_token: "fresh-token", + token_type: "Bearer", + expires_in: 3600, + }); + + const unauthorizedResponse = new Response('{"error":"invalid_token"}', { + status: 401, + headers: { "www-authenticate": 'Bearer realm="mcp"' }, + }); + const successResponse = new Response( + '{"jsonrpc":"2.0","id":1,"result":{}}', + { + status: 200, + }, + ); + + mockFetch + .mockResolvedValueOnce(unauthorizedResponse) + .mockResolvedValueOnce(successResponse); + + mockExtractResourceMetadataUrl.mockReturnValue( + new URL("https://auth.example.com/.well-known/oauth-protected-resource"), + ); + mockAuth.mockResolvedValue("AUTHORIZED"); + + // Use custom logger to avoid console output + const mockLogger = jest.fn(); + const enhancedFetch = applyMiddlewares( + withOAuth( + mockProvider as OAuthClientProvider, + "https://mcp-server.example.com/", + ), + withLogging({ logger: mockLogger, statusLevel: 0 }), + )(mockFetch); + + const result = await enhancedFetch( + "https://mcp-server.example.com/endpoint", + { + method: "POST", + body: JSON.stringify({ jsonrpc: "2.0", method: "test", id: 1 }), + }, + ); + + expect(result).toBe(successResponse); + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(mockAuth).toHaveBeenCalledWith(mockProvider, { + serverUrl: "https://mcp-server.example.com/", + resourceMetadataUrl: new URL( + "https://auth.example.com/.well-known/oauth-protected-resource", + ), + fetchFn: mockFetch, + }); + }); +}); + +describe("createMiddleware", () => { + let mockFetch: jest.MockedFunction; + + beforeEach(() => { + jest.clearAllMocks(); + mockFetch = jest.fn(); + }); + + it("should create middleware with cleaner syntax", async () => { + const response = new Response("success", { status: 200 }); + mockFetch.mockResolvedValue(response); + + const customMiddleware = createMiddleware(async (next, input, init) => { + const headers = new Headers(init?.headers); + headers.set("X-Custom-Header", "custom-value"); + return next(input, { ...init, headers }); + }); + + const enhancedFetch = customMiddleware(mockFetch); + await enhancedFetch("https://api.example.com/data"); + + expect(mockFetch).toHaveBeenCalledWith( + "https://api.example.com/data", + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("X-Custom-Header")).toBe("custom-value"); + }); + + it("should support conditional middleware logic", async () => { + const apiResponse = new Response("api response", { status: 200 }); + const publicResponse = new Response("public response", { status: 200 }); + mockFetch + .mockResolvedValueOnce(apiResponse) + .mockResolvedValueOnce(publicResponse); + + const conditionalMiddleware = createMiddleware( + async (next, input, init) => { + const url = typeof input === "string" ? input : input.toString(); + + if (url.includes("/api/")) { + const headers = new Headers(init?.headers); + headers.set("X-API-Version", "v2"); + return next(input, { ...init, headers }); + } + + return next(input, init); + }, + ); + + const enhancedFetch = conditionalMiddleware(mockFetch); + + // Test API route + await enhancedFetch("https://example.com/api/users"); + let callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("X-API-Version")).toBe("v2"); + + // Test non-API route + await enhancedFetch("https://example.com/public/page"); + callArgs = mockFetch.mock.calls[1]; + const maybeHeaders = callArgs[1]?.headers as Headers | undefined; + expect(maybeHeaders?.get("X-API-Version")).toBeUndefined(); + }); + + it("should support short-circuit responses", async () => { + const customMiddleware = createMiddleware(async (next, input, init) => { + const url = typeof input === "string" ? input : input.toString(); + + // Short-circuit for specific URL + if (url.includes("/cached")) { + return new Response("cached data", { status: 200 }); + } + + return next(input, init); + }); + + const enhancedFetch = customMiddleware(mockFetch); + + // Test cached route (should not call mockFetch) + const cachedResponse = await enhancedFetch( + "https://example.com/cached/data", + ); + expect(await cachedResponse.text()).toBe("cached data"); + expect(mockFetch).not.toHaveBeenCalled(); + + // Test normal route + mockFetch.mockResolvedValue(new Response("fresh data", { status: 200 })); + const normalResponse = await enhancedFetch("https://example.com/normal/data"); + expect(await normalResponse.text()).toBe("fresh data"); + expect(mockFetch).toHaveBeenCalledTimes(1); + }); + + it("should handle response transformation", async () => { + const originalResponse = new Response('{"data": "original"}', { + status: 200, + headers: { "Content-Type": "application/json" }, + }); + mockFetch.mockResolvedValue(originalResponse); + + const transformMiddleware = createMiddleware(async (next, input, init) => { + const response = await next(input, init); + + if (response.headers.get("content-type")?.includes("application/json")) { + const data = await response.json(); + const transformed = { ...data, timestamp: 123456789 }; + + return new Response(JSON.stringify(transformed), { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }); + } + + return response; + }); + + const enhancedFetch = transformMiddleware(mockFetch); + const response = await enhancedFetch("https://api.example.com/data"); + const result = await response.json(); + + expect(result).toEqual({ + data: "original", + timestamp: 123456789, + }); + }); + + it("should support error handling and recovery", async () => { + let attemptCount = 0; + mockFetch.mockImplementation(async () => { + attemptCount++; + if (attemptCount === 1) { + throw new Error("Network error"); + } + return new Response("success", { status: 200 }); + }); + + const retryMiddleware = createMiddleware(async (next, input, init) => { + try { + return await next(input, init); + } catch (error) { + // Retry once on network error + console.log("Retrying request after error:", error); + return await next(input, init); + } + }); + + const enhancedFetch = retryMiddleware(mockFetch); + const response = await enhancedFetch("https://api.example.com/data"); + + expect(await response.text()).toBe("success"); + expect(mockFetch).toHaveBeenCalledTimes(2); + }); + + it("should compose well with other middleware", async () => { + const response = new Response("success", { status: 200 }); + mockFetch.mockResolvedValue(response); + + // Create custom middleware using createMiddleware + const customAuth = createMiddleware(async (next, input, init) => { + const headers = new Headers(init?.headers); + headers.set("Authorization", "Custom token"); + return next(input, { ...init, headers }); + }); + + const customLogging = createMiddleware(async (next, input, init) => { + const url = typeof input === "string" ? input : input.toString(); + console.log(`Request to: ${url}`); + const response = await next(input, init); + console.log(`Response status: ${response.status}`); + return response; + }); + + // Compose with existing middleware + const enhancedFetch = applyMiddlewares( + customAuth, + customLogging, + withLogging({ statusLevel: 400 }), + )(mockFetch); + + await enhancedFetch("https://api.example.com/data"); + + const callArgs = mockFetch.mock.calls[0]; + const headers = callArgs[1]?.headers as Headers; + expect(headers.get("Authorization")).toBe("Custom token"); + }); + + it("should have access to both input types (string and URL)", async () => { + const response = new Response("success", { status: 200 }); + mockFetch.mockResolvedValue(response); + + let capturedInputType: string | undefined; + const inspectMiddleware = createMiddleware(async (next, input, init) => { + capturedInputType = typeof input === "string" ? "string" : "URL"; + return next(input, init); + }); + + const enhancedFetch = inspectMiddleware(mockFetch); + + // Test with string input + await enhancedFetch("https://api.example.com/data"); + expect(capturedInputType).toBe("string"); + + // Test with URL input + await enhancedFetch(new URL("https://api.example.com/data")); + expect(capturedInputType).toBe("URL"); + }); +}); diff --git a/src/client/middleware.ts b/src/client/middleware.ts new file mode 100644 index 000000000..3d0661584 --- /dev/null +++ b/src/client/middleware.ts @@ -0,0 +1,358 @@ +import { + auth, + extractResourceMetadataUrl, + OAuthClientProvider, + UnauthorizedError, +} from "./auth.js"; +import { FetchLike } from "../shared/transport.js"; + +/** + * Middleware function that wraps and enhances fetch functionality. + * Takes a fetch handler and returns an enhanced fetch handler. + */ +export type Middleware = (next: FetchLike) => FetchLike; + +/** + * Creates a fetch wrapper that handles OAuth authentication automatically. + * + * This wrapper will: + * - Add Authorization headers with access tokens + * - Handle 401 responses by attempting re-authentication + * - Retry the original request after successful auth + * - Handle OAuth errors appropriately (InvalidClientError, etc.) + * + * The baseUrl parameter is optional and defaults to using the domain from the request URL. + * However, you should explicitly provide baseUrl when: + * - Making requests to multiple subdomains (e.g., api.example.com, cdn.example.com) + * - Using API paths that differ from OAuth discovery paths (e.g., requesting /api/v1/data but OAuth is at /) + * - The OAuth server is on a different domain than your API requests + * - You want to ensure consistent OAuth behavior regardless of request URLs + * + * For MCP transports, set baseUrl to the same URL you pass to the transport constructor. + * + * Note: This wrapper is designed for general-purpose fetch operations. + * MCP transports (SSE and StreamableHTTP) already have built-in OAuth handling + * and should not need this wrapper. + * + * @param provider - OAuth client provider for authentication + * @param baseUrl - Base URL for OAuth server discovery (defaults to request URL domain) + * @returns A fetch middleware function + */ +export const withOAuth = + (provider: OAuthClientProvider, baseUrl?: string | URL): Middleware => + (next) => { + return async (input, init) => { + const makeRequest = async (): Promise => { + const headers = new Headers(init?.headers); + + // Add authorization header if tokens are available + const tokens = await provider.tokens(); + if (tokens) { + headers.set("Authorization", `Bearer ${tokens.access_token}`); + } + + return await next(input, { ...init, headers }); + }; + + let response = await makeRequest(); + + // Handle 401 responses by attempting re-authentication + if (response.status === 401) { + try { + const resourceMetadataUrl = extractResourceMetadataUrl(response); + + // Use provided baseUrl or extract from request URL + const serverUrl = + baseUrl || + (typeof input === "string" ? new URL(input).origin : input.origin); + + const result = await auth(provider, { + serverUrl, + resourceMetadataUrl, + fetchFn: next, + }); + + if (result === "REDIRECT") { + throw new UnauthorizedError( + "Authentication requires user authorization - redirect initiated", + ); + } + + if (result !== "AUTHORIZED") { + throw new UnauthorizedError( + `Authentication failed with result: ${result}`, + ); + } + + // Retry the request with fresh tokens + response = await makeRequest(); + } catch (error) { + if (error instanceof UnauthorizedError) { + throw error; + } + throw new UnauthorizedError( + `Failed to re-authenticate: ${error instanceof Error ? error.message : String(error)}`, + ); + } + } + + // If we still have a 401 after re-auth attempt, throw an error + if (response.status === 401) { + const url = typeof input === "string" ? input : input.toString(); + throw new UnauthorizedError(`Authentication failed for ${url}`); + } + + return response; + }; + }; + +/** + * Logger function type for HTTP requests + */ +export type RequestLogger = (input: { + method: string; + url: string | URL; + status: number; + statusText: string; + duration: number; + requestHeaders?: Headers; + responseHeaders?: Headers; + error?: Error; +}) => void; + +/** + * Configuration options for the logging middleware + */ +export type LoggingOptions = { + /** + * Custom logger function, defaults to console logging + */ + logger?: RequestLogger; + + /** + * Whether to include request headers in logs + * @default false + */ + includeRequestHeaders?: boolean; + + /** + * Whether to include response headers in logs + * @default false + */ + includeResponseHeaders?: boolean; + + /** + * Status level filter - only log requests with status >= this value + * Set to 0 to log all requests, 400 to log only errors + * @default 0 + */ + statusLevel?: number; +}; + +/** + * Creates a fetch middleware that logs HTTP requests and responses. + * + * When called without arguments `withLogging()`, it uses the default logger that: + * - Logs successful requests (2xx) to `console.log` + * - Logs error responses (4xx/5xx) and network errors to `console.error` + * - Logs all requests regardless of status (statusLevel: 0) + * - Does not include request or response headers in logs + * - Measures and displays request duration in milliseconds + * + * Important: the default logger uses both `console.log` and `console.error` so it should not be used with + * `stdio` transports and applications. + * + * @param options - Logging configuration options + * @returns A fetch middleware function + */ +export const withLogging = (options: LoggingOptions = {}): Middleware => { + const { + logger, + includeRequestHeaders = false, + includeResponseHeaders = false, + statusLevel = 0, + } = options; + + const defaultLogger: RequestLogger = (input) => { + const { + method, + url, + status, + statusText, + duration, + requestHeaders, + responseHeaders, + error, + } = input; + + let message = error + ? `HTTP ${method} ${url} failed: ${error.message} (${duration}ms)` + : `HTTP ${method} ${url} ${status} ${statusText} (${duration}ms)`; + + // Add headers to message if requested + if (includeRequestHeaders && requestHeaders) { + const reqHeaders = Array.from(requestHeaders.entries()) + .map(([key, value]) => `${key}: ${value}`) + .join(", "); + message += `\n Request Headers: {${reqHeaders}}`; + } + + if (includeResponseHeaders && responseHeaders) { + const resHeaders = Array.from(responseHeaders.entries()) + .map(([key, value]) => `${key}: ${value}`) + .join(", "); + message += `\n Response Headers: {${resHeaders}}`; + } + + if (error || status >= 400) { + // eslint-disable-next-line no-console + console.error(message); + } else { + // eslint-disable-next-line no-console + console.log(message); + } + }; + + const logFn = logger || defaultLogger; + + return (next) => async (input, init) => { + const startTime = performance.now(); + const method = init?.method || "GET"; + const url = typeof input === "string" ? input : input.toString(); + const requestHeaders = includeRequestHeaders + ? new Headers(init?.headers) + : undefined; + + try { + const response = await next(input, init); + const duration = performance.now() - startTime; + + // Only log if status meets the log level threshold + if (response.status >= statusLevel) { + logFn({ + method, + url, + status: response.status, + statusText: response.statusText, + duration, + requestHeaders, + responseHeaders: includeResponseHeaders + ? response.headers + : undefined, + }); + } + + return response; + } catch (error) { + const duration = performance.now() - startTime; + + // Always log errors regardless of log level + logFn({ + method, + url, + status: 0, + statusText: "Network Error", + duration, + requestHeaders, + error: error as Error, + }); + + throw error; + } + }; +}; + +/** + * Composes multiple fetch middleware functions into a single middleware pipeline. + * Middleware are applied in the order they appear, creating a chain of handlers. + * + * @example + * ```typescript + * // Create a middleware pipeline that handles both OAuth and logging + * const enhancedFetch = applyMiddlewares( + * withOAuth(oauthProvider, 'https://api.example.com/'), + * withLogging({ statusLevel: 400 }) + * )(fetch); + * + * // Use the enhanced fetch - it will handle auth and log errors + * const response = await enhancedFetch('https://api.example.com/data'); + * ``` + * + * @param middleware - Array of fetch middleware to compose into a pipeline + * @returns A single composed middleware function + */ +export const applyMiddlewares = ( + ...middleware: Middleware[] +): Middleware => { + return (next) => { + return middleware.reduce((handler, mw) => mw(handler), next); + }; +}; + +/** + * Helper function to create custom fetch middleware with cleaner syntax. + * Provides the next handler and request details as separate parameters for easier access. + * + * @example + * ```typescript + * // Create custom authentication middleware + * const customAuthMiddleware = createMiddleware(async (next, input, init) => { + * const headers = new Headers(init?.headers); + * headers.set('X-Custom-Auth', 'my-token'); + * + * const response = await next(input, { ...init, headers }); + * + * if (response.status === 401) { + * console.log('Authentication failed'); + * } + * + * return response; + * }); + * + * // Create conditional middleware + * const conditionalMiddleware = createMiddleware(async (next, input, init) => { + * const url = typeof input === 'string' ? input : input.toString(); + * + * // Only add headers for API routes + * if (url.includes('/api/')) { + * const headers = new Headers(init?.headers); + * headers.set('X-API-Version', 'v2'); + * return next(input, { ...init, headers }); + * } + * + * // Pass through for non-API routes + * return next(input, init); + * }); + * + * // Create caching middleware + * const cacheMiddleware = createMiddleware(async (next, input, init) => { + * const cacheKey = typeof input === 'string' ? input : input.toString(); + * + * // Check cache first + * const cached = await getFromCache(cacheKey); + * if (cached) { + * return new Response(cached, { status: 200 }); + * } + * + * // Make request and cache result + * const response = await next(input, init); + * if (response.ok) { + * await saveToCache(cacheKey, await response.clone().text()); + * } + * + * return response; + * }); + * ``` + * + * @param handler - Function that receives the next handler and request parameters + * @returns A fetch middleware function + */ +export const createMiddleware = ( + handler: ( + next: FetchLike, + input: string | URL, + init?: RequestInit, + ) => Promise, +): Middleware => { + return (next) => (input, init) => handler(next, input as string | URL, init); +}; From 64f7cdd09bf99031510b52a1e957e40febea997b Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Tue, 19 Aug 2025 15:39:21 +0100 Subject: [PATCH 2/4] restrict url schemes allowed in oauth metadata (#877) Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Claude --- src/shared/auth.test.ts | 116 ++++++++++++++++++++++++++++++++++++++++ src/shared/auth.ts | 59 +++++++++++++------- 2 files changed, 157 insertions(+), 18 deletions(-) create mode 100644 src/shared/auth.test.ts diff --git a/src/shared/auth.test.ts b/src/shared/auth.test.ts new file mode 100644 index 000000000..c1ed82ba2 --- /dev/null +++ b/src/shared/auth.test.ts @@ -0,0 +1,116 @@ +import { describe, it, expect } from '@jest/globals'; +import { + SafeUrlSchema, + OAuthMetadataSchema, + OpenIdProviderMetadataSchema, + OAuthClientMetadataSchema, +} from './auth.js'; + +describe('SafeUrlSchema', () => { + it('accepts valid HTTPS URLs', () => { + expect(SafeUrlSchema.parse('https://example.com/')).toBe('https://example.com/'); + expect(SafeUrlSchema.parse('https://auth.example.com/oauth/authorize')).toBe('https://auth.example.com/oauth/authorize'); + }); + + it('accepts valid HTTP URLs', () => { + expect(SafeUrlSchema.parse('http://localhost:3000/')).toBe('http://localhost:3000/'); + }); + + it('rejects javascript: scheme URLs', () => { + expect(() => SafeUrlSchema.parse('javascript:alert(1)')).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + expect(() => SafeUrlSchema.parse('JAVASCRIPT:alert(1)')).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); + + it('rejects invalid URLs', () => { + expect(() => SafeUrlSchema.parse('not-a-url')).toThrow(); + expect(() => SafeUrlSchema.parse('')).toThrow(); + }); + + it('works with safeParse', () => { + expect(() => SafeUrlSchema.safeParse('not-a-url')).not.toThrow(); + }); +}); + +describe('OAuthMetadataSchema', () => { + it('validates complete OAuth metadata', () => { + const metadata = { + issuer: 'https://auth.example.com/', + authorization_endpoint: 'https://auth.example.com/oauth/authorize', + token_endpoint: 'https://auth.example.com/oauth/token', + response_types_supported: ['code'], + scopes_supported: ['read', 'write'], + }; + + expect(() => OAuthMetadataSchema.parse(metadata)).not.toThrow(); + }); + + it('rejects metadata with javascript: URLs', () => { + const metadata = { + issuer: 'https://auth.example.com/', + authorization_endpoint: 'javascript:alert(1)', + token_endpoint: 'https://auth.example.com/oauth/token', + response_types_supported: ['code'], + }; + + expect(() => OAuthMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); + + it('requires mandatory fields', () => { + const incompleteMetadata = { + issuer: 'https://auth.example.com/', + }; + + expect(() => OAuthMetadataSchema.parse(incompleteMetadata)).toThrow(); + }); +}); + +describe('OpenIdProviderMetadataSchema', () => { + it('validates complete OpenID Provider metadata', () => { + const metadata = { + issuer: 'https://auth.example.com/', + authorization_endpoint: 'https://auth.example.com/oauth/authorize', + token_endpoint: 'https://auth.example.com/oauth/token', + jwks_uri: 'https://auth.example.com/.well-known/jwks.json', + response_types_supported: ['code'], + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + }; + + expect(() => OpenIdProviderMetadataSchema.parse(metadata)).not.toThrow(); + }); + + it('rejects metadata with javascript: in jwks_uri', () => { + const metadata = { + issuer: 'https://auth.example.com/', + authorization_endpoint: 'https://auth.example.com/oauth/authorize', + token_endpoint: 'https://auth.example.com/oauth/token', + jwks_uri: 'javascript:alert(1)', + response_types_supported: ['code'], + subject_types_supported: ['public'], + id_token_signing_alg_values_supported: ['RS256'], + }; + + expect(() => OpenIdProviderMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); +}); + +describe('OAuthClientMetadataSchema', () => { + it('validates client metadata with safe URLs', () => { + const metadata = { + redirect_uris: ['https://app.example.com/callback'], + client_name: 'Test App', + client_uri: 'https://app.example.com/', + }; + + expect(() => OAuthClientMetadataSchema.parse(metadata)).not.toThrow(); + }); + + it('rejects client metadata with javascript: redirect URIs', () => { + const metadata = { + redirect_uris: ['javascript:alert(1)'], + client_name: 'Test App', + }; + + expect(() => OAuthClientMetadataSchema.parse(metadata)).toThrow('URL cannot use javascript:, data:, or vbscript: scheme'); + }); +}); diff --git a/src/shared/auth.ts b/src/shared/auth.ts index 47eba9ac5..886eb1084 100644 --- a/src/shared/auth.ts +++ b/src/shared/auth.ts @@ -1,12 +1,35 @@ import { z } from "zod"; +/** + * Reusable URL validation that disallows javascript: scheme + */ +export const SafeUrlSchema = z.string().url() + .superRefine((val, ctx) => { + if (!URL.canParse(val)) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: "URL must be parseable", + fatal: true, + }); + + return z.NEVER; + } + }).refine( + (url) => { + const u = new URL(url); + return u.protocol !== 'javascript:' && u.protocol !== 'data:' && u.protocol !== 'vbscript:'; + }, + { message: "URL cannot use javascript:, data:, or vbscript: scheme" } +); + + /** * RFC 9728 OAuth Protected Resource Metadata */ export const OAuthProtectedResourceMetadataSchema = z .object({ resource: z.string().url(), - authorization_servers: z.array(z.string().url()).optional(), + authorization_servers: z.array(SafeUrlSchema).optional(), jwks_uri: z.string().url().optional(), scopes_supported: z.array(z.string()).optional(), bearer_methods_supported: z.array(z.string()).optional(), @@ -28,9 +51,9 @@ export const OAuthProtectedResourceMetadataSchema = z export const OAuthMetadataSchema = z .object({ issuer: z.string(), - authorization_endpoint: z.string(), - token_endpoint: z.string(), - registration_endpoint: z.string().optional(), + authorization_endpoint: SafeUrlSchema, + token_endpoint: SafeUrlSchema, + registration_endpoint: SafeUrlSchema.optional(), scopes_supported: z.array(z.string()).optional(), response_types_supported: z.array(z.string()), response_modes_supported: z.array(z.string()).optional(), @@ -39,8 +62,8 @@ export const OAuthMetadataSchema = z token_endpoint_auth_signing_alg_values_supported: z .array(z.string()) .optional(), - service_documentation: z.string().optional(), - revocation_endpoint: z.string().optional(), + service_documentation: SafeUrlSchema.optional(), + revocation_endpoint: SafeUrlSchema.optional(), revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), revocation_endpoint_auth_signing_alg_values_supported: z .array(z.string()) @@ -63,11 +86,11 @@ export const OAuthMetadataSchema = z export const OpenIdProviderMetadataSchema = z .object({ issuer: z.string(), - authorization_endpoint: z.string(), - token_endpoint: z.string(), - userinfo_endpoint: z.string().optional(), - jwks_uri: z.string(), - registration_endpoint: z.string().optional(), + authorization_endpoint: SafeUrlSchema, + token_endpoint: SafeUrlSchema, + userinfo_endpoint: SafeUrlSchema.optional(), + jwks_uri: SafeUrlSchema, + registration_endpoint: SafeUrlSchema.optional(), scopes_supported: z.array(z.string()).optional(), response_types_supported: z.array(z.string()), response_modes_supported: z.array(z.string()).optional(), @@ -101,8 +124,8 @@ export const OpenIdProviderMetadataSchema = z request_parameter_supported: z.boolean().optional(), request_uri_parameter_supported: z.boolean().optional(), require_request_uri_registration: z.boolean().optional(), - op_policy_uri: z.string().optional(), - op_tos_uri: z.string().optional(), + op_policy_uri: SafeUrlSchema.optional(), + op_tos_uri: SafeUrlSchema.optional(), }) .passthrough(); @@ -146,18 +169,18 @@ export const OAuthErrorResponseSchema = z * RFC 7591 OAuth 2.0 Dynamic Client Registration metadata */ export const OAuthClientMetadataSchema = z.object({ - redirect_uris: z.array(z.string()).refine((uris) => uris.every((uri) => URL.canParse(uri)), { message: "redirect_uris must contain valid URLs" }), + redirect_uris: z.array(SafeUrlSchema), token_endpoint_auth_method: z.string().optional(), grant_types: z.array(z.string()).optional(), response_types: z.array(z.string()).optional(), client_name: z.string().optional(), - client_uri: z.string().optional(), - logo_uri: z.string().optional(), + client_uri: SafeUrlSchema.optional(), + logo_uri: SafeUrlSchema.optional(), scope: z.string().optional(), contacts: z.array(z.string()).optional(), - tos_uri: z.string().optional(), + tos_uri: SafeUrlSchema.optional(), policy_uri: z.string().optional(), - jwks_uri: z.string().optional(), + jwks_uri: SafeUrlSchema.optional(), jwks: z.any().optional(), software_id: z.string().optional(), software_version: z.string().optional(), From 1f5950be01927cb8e4e555cb5584c44db1604f39 Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Tue, 19 Aug 2025 15:43:09 +0100 Subject: [PATCH 3/4] [auth] OAuth protected-resource-metadata: fallback on 4xx not just 404 (#879) Co-authored-by: adam jones --- src/client/auth.test.ts | 22 ++++++++-- src/client/auth.ts | 2 +- src/client/streamableHttp.test.ts | 70 +++++++++++++++++-------------- 3 files changed, 57 insertions(+), 37 deletions(-) diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index fb9b31006..f28163d14 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -212,11 +212,11 @@ describe("OAuth Authorization", () => { expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource/path?param=value"); }); - it("falls back to root discovery when path-aware discovery returns 404", async () => { - // First call (path-aware) returns 404 + it.each([400, 401, 403, 404, 410, 422, 429])("falls back to root discovery when path-aware discovery returns %d", async (statusCode) => { + // First call (path-aware) returns 4xx mockFetch.mockResolvedValueOnce({ ok: false, - status: 404, + status: statusCode, }); // Second call (root fallback) succeeds @@ -267,6 +267,20 @@ describe("OAuth Authorization", () => { expect(calls.length).toBe(2); }); + it("throws error on 500 status and does not fallback", async () => { + // First call (path-aware) returns 500 + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500, + }); + + await expect(discoverOAuthProtectedResourceMetadata("https://resource.example.com/path/name")) + .rejects.toThrow(); + + const calls = mockFetch.mock.calls; + expect(calls.length).toBe(1); // Should not attempt fallback + }); + it("does not fallback when the original URL is already at root path", async () => { // First call (path-aware for root) returns 404 mockFetch.mockResolvedValueOnce({ @@ -907,7 +921,7 @@ describe("OAuth Authorization", () => { const metadata = await discoverAuthorizationServerMetadata("https://auth.example.com/tenant1"); expect(metadata).toBeUndefined(); - + // Verify that all discovery URLs were attempted expect(mockFetch).toHaveBeenCalledTimes(8); // 4 URLs × 2 attempts each (with and without headers) }); diff --git a/src/client/auth.ts b/src/client/auth.ts index 8ac9ddd1e..fcc320f17 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -571,7 +571,7 @@ async function tryMetadataDiscovery( * Determines if fallback to root discovery should be attempted */ function shouldAttemptFallback(response: Response | undefined, pathname: string): boolean { - return !response || response.status === 404 && pathname !== '/'; + return !response || (response.status >= 400 && response.status < 500) && pathname !== '/'; } /** diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 88fd48017..fdd35ed3f 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -465,7 +465,7 @@ describe("StreamableHTTPClientTransport", () => { // Verify custom fetch was used expect(customFetch).toHaveBeenCalled(); - + // Global fetch should never have been called expect(global.fetch).not.toHaveBeenCalled(); }); @@ -589,32 +589,32 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); - + describe('Reconnection Logic', () => { let transport: StreamableHTTPClientTransport; - + // Use fake timers to control setTimeout and make the test instant. beforeEach(() => jest.useFakeTimers()); afterEach(() => jest.useRealTimers()); - + it('should reconnect a GET-initiated notification stream that fails', async () => { // ARRANGE transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { reconnectionOptions: { - initialReconnectionDelay: 10, - maxRetries: 1, + initialReconnectionDelay: 10, + maxRetries: 1, maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity } }); - + const errorSpy = jest.fn(); transport.onerror = errorSpy; - + const failingStream = new ReadableStream({ start(controller) { controller.error(new Error("Network failure")); } }); - + const fetchMock = global.fetch as jest.Mock; // Mock the initial GET request, which will fail. fetchMock.mockResolvedValueOnce({ @@ -628,13 +628,13 @@ describe("StreamableHTTPClientTransport", () => { headers: new Headers({ "content-type": "text/event-stream" }), body: new ReadableStream(), }); - + // ACT await transport.start(); // Trigger the GET stream directly using the internal method for a clean test. await transport["_startOrAuthSse"]({}); await jest.advanceTimersByTimeAsync(20); // Trigger reconnection timeout - + // ASSERT expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), @@ -644,25 +644,25 @@ describe("StreamableHTTPClientTransport", () => { expect(fetchMock.mock.calls[0][1]?.method).toBe('GET'); expect(fetchMock.mock.calls[1][1]?.method).toBe('GET'); }); - + it('should NOT reconnect a POST-initiated stream that fails', async () => { // ARRANGE transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { - reconnectionOptions: { - initialReconnectionDelay: 10, - maxRetries: 1, + reconnectionOptions: { + initialReconnectionDelay: 10, + maxRetries: 1, maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity } }); - + const errorSpy = jest.fn(); transport.onerror = errorSpy; - + const failingStream = new ReadableStream({ start(controller) { controller.error(new Error("Network failure")); } }); - + const fetchMock = global.fetch as jest.Mock; // Mock the POST request. It returns a streaming content-type but a failing body. fetchMock.mockResolvedValueOnce({ @@ -670,7 +670,7 @@ describe("StreamableHTTPClientTransport", () => { headers: new Headers({ "content-type": "text/event-stream" }), body: failingStream, }); - + // A dummy request message to trigger the `send` logic. const requestMessage: JSONRPCRequest = { jsonrpc: '2.0', @@ -678,13 +678,13 @@ describe("StreamableHTTPClientTransport", () => { id: 'request-1', params: {}, }; - + // ACT await transport.start(); // Use the public `send` method to initiate a POST that gets a stream response. await transport.send(requestMessage); await jest.advanceTimersByTimeAsync(20); // Advance time to check for reconnections - + // ASSERT expect(errorSpy).toHaveBeenCalledWith(expect.objectContaining({ message: expect.stringContaining('SSE stream disconnected: Error: Network failure'), @@ -718,7 +718,9 @@ describe("StreamableHTTPClientTransport", () => { (global.fetch as jest.Mock) // Initial connection .mockResolvedValueOnce(unauthedResponse) - // Resource discovery + // Resource discovery, path aware + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, root .mockResolvedValueOnce(unauthedResponse) // OAuth metadata discovery .mockResolvedValueOnce({ @@ -770,7 +772,9 @@ describe("StreamableHTTPClientTransport", () => { (global.fetch as jest.Mock) // Initial connection .mockResolvedValueOnce(unauthedResponse) - // Resource discovery + // Resource discovery, path aware + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, root .mockResolvedValueOnce(unauthedResponse) // OAuth metadata discovery .mockResolvedValueOnce({ @@ -822,7 +826,9 @@ describe("StreamableHTTPClientTransport", () => { (global.fetch as jest.Mock) // Initial connection .mockResolvedValueOnce(unauthedResponse) - // Resource discovery + // Resource discovery, path aware + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery, root .mockResolvedValueOnce(unauthedResponse) // OAuth metadata discovery .mockResolvedValueOnce({ @@ -888,7 +894,7 @@ describe("StreamableHTTPClientTransport", () => { ok: false, status: 404 }); - + // Create transport instance transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { authProvider: mockAuthProvider, @@ -901,14 +907,14 @@ describe("StreamableHTTPClientTransport", () => { // Verify custom fetch was used expect(customFetch).toHaveBeenCalled(); - + // Verify specific OAuth endpoints were called with custom fetch const customFetchCalls = customFetch.mock.calls; const callUrls = customFetchCalls.map(([url]) => url.toString()); - + // Should have called resource metadata discovery expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); - + // Should have called OAuth authorization server metadata discovery expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); @@ -966,19 +972,19 @@ describe("StreamableHTTPClientTransport", () => { // Verify custom fetch was used expect(customFetch).toHaveBeenCalled(); - + // Verify specific OAuth endpoints were called with custom fetch const customFetchCalls = customFetch.mock.calls; const callUrls = customFetchCalls.map(([url]) => url.toString()); - + // Should have called resource metadata discovery expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); - + // Should have called OAuth authorization server metadata discovery expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); // Should have called token endpoint for authorization code exchange - const tokenCalls = customFetchCalls.filter(([url, options]) => + const tokenCalls = customFetchCalls.filter(([url, options]) => url.toString().includes('/token') && options?.method === "POST" ); expect(tokenCalls.length).toBeGreaterThan(0); From 3bc2235d747c320dfa0b6227cc84414c6d0add89 Mon Sep 17 00:00:00 2001 From: Felix Weinberger <3823880+felixweinberger@users.noreply.github.com> Date: Thu, 21 Aug 2025 17:08:09 +0100 Subject: [PATCH 4/4] chore: bump version to 1.17.4 (#894) --- package-lock.json | 4 ++-- package.json | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/package-lock.json b/package-lock.json index 1e0b12ed7..8759a701e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.17.3", + "version": "1.17.4", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.17.3", + "version": "1.17.4", "license": "MIT", "dependencies": { "ajv": "^6.12.6", diff --git a/package.json b/package.json index 697b051be..8be8f1002 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.17.3", + "version": "1.17.4", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)",