diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a4cd0ac0..1fe659cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: coverage: runs-on: ubuntu-latest - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' + if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' || (github.event_name == 'push' && github.ref == 'refs/heads/main') permissions: contents: read pull-requests: write @@ -38,6 +38,7 @@ jobs: with: name: code-coverage path: coverage.txt + retention-days: 30 - name: Generate coverage report uses: fgrosse/go-coverage-report@v1.2.0 diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index 31c70887..663a8174 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -525,9 +525,7 @@ func TestSSE(t *testing.T) { t.Fatalf("SendRequest failed: %v", err) } - if response == nil { - t.Fatal("Expected response, got nil") - } + require.NotNil(t, response, "Expected response, got nil") // Verify the response var result string diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index cdd5a93e..2b72f42c 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" ) // startMockStreamableHTTPServer starts a test HTTP server that implements @@ -535,9 +536,7 @@ func TestStreamableHTTP(t *testing.T) { t.Fatalf("SendRequest failed: %v", err) } - if response == nil { - t.Fatal("Expected response, got nil") - } + require.NotNil(t, response, "Expected response, got nil") // Verify the response var result string diff --git a/server/streamable_http.go b/server/streamable_http.go index 5a596467..4535943d 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -52,13 +52,13 @@ func WithStateLess(stateLess bool) StreamableHTTPOption { } // WithSessionIdManager sets a custom session id generator for the server. -// By default, the server uses InsecureStatefulSessionIdManager (UUID-based; insecure). +// By default, the server uses StatelessGeneratingSessionIdManager (generates IDs but no local validation). // Note: Options are applied in order; the last one wins. If combined with // WithStateLess or WithSessionIdManagerResolver, whichever is applied last takes effect. func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption { return func(s *StreamableHTTPServer) { if manager == nil { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) + s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{}) return } s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(manager) @@ -72,13 +72,23 @@ func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption { func WithSessionIdManagerResolver(resolver SessionIdManagerResolver) StreamableHTTPOption { return func(s *StreamableHTTPServer) { if resolver == nil { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) + s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{}) return } s.sessionIdManagerResolver = resolver } } +// WithStateful enables stateful session management using InsecureStatefulSessionIdManager. +// This requires sticky sessions in multi-instance deployments. +func WithStateful(stateful bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + if stateful { + s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) + } + } +} + // WithHeartbeatInterval sets the heartbeat interval. Positive interval means the // server will send a heartbeat to the client through the GET connection, to keep // the connection alive from being closed by the network infrastructure (e.g. @@ -187,7 +197,7 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S sessionTools: newSessionToolsStore(), sessionLogLevels: newSessionLogLevelsStore(), endpointPath: "/mcp", - sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}), + sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&StatelessGeneratingSessionIdManager{}), logger: util.DefaultLogger(), sessionResources: newSessionResourcesStore(), sessionResourceTemplates: newSessionResourceTemplatesStore(), @@ -976,6 +986,8 @@ type streamableHttpSession struct { resourceTemplates *sessionResourceTemplatesStore upgradeToSSE atomic.Bool logLevels *sessionLogLevelsStore + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities // Sampling support for bidirectional communication samplingRequestChan chan samplingRequestItem // server -> client sampling requests @@ -1053,11 +1065,38 @@ func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string s.resourceTemplates.set(s.sessionID, templates) } +func (s *streamableHttpSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *streamableHttpSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *streamableHttpSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *streamableHttpSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + var ( _ SessionWithTools = (*streamableHttpSession)(nil) _ SessionWithResources = (*streamableHttpSession)(nil) _ SessionWithResourceTemplates = (*streamableHttpSession)(nil) _ SessionWithLogging = (*streamableHttpSession)(nil) + _ SessionWithClientInfo = (*streamableHttpSession)(nil) ) func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { @@ -1244,7 +1283,7 @@ type DefaultSessionIdManagerResolver struct { // NewDefaultSessionIdManagerResolver creates a new DefaultSessionIdManagerResolver with the given SessionIdManager func NewDefaultSessionIdManagerResolver(manager SessionIdManager) *DefaultSessionIdManagerResolver { if manager == nil { - manager = &InsecureStatefulSessionIdManager{} + manager = &StatelessSessionIdManager{} } return &DefaultSessionIdManagerResolver{manager: manager} } @@ -1270,6 +1309,30 @@ func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bo return false, nil } +// StatelessGeneratingSessionIdManager generates session IDs but doesn't validate them locally. +// This allows session IDs to be generated for clients while working across multiple instances. +type StatelessGeneratingSessionIdManager struct{} + +func (s *StatelessGeneratingSessionIdManager) Generate() string { + return idPrefix + uuid.New().String() +} + +func (s *StatelessGeneratingSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // Only validate format, not existence - allows cross-instance operation + if !strings.HasPrefix(sessionID, idPrefix) { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + return false, nil +} + +func (s *StatelessGeneratingSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + // No-op termination since we don't track sessions + return false, nil +} + // InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions. // It validates both format and existence of session IDs. // For more secure session id, use a more complex generator, like a JWT. diff --git a/server/streamable_http_client_info_test.go b/server/streamable_http_client_info_test.go new file mode 100644 index 00000000..866721fa --- /dev/null +++ b/server/streamable_http_client_info_test.go @@ -0,0 +1,60 @@ +package server + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestStreamableHttpSessionImplementsSessionWithClientInfo(t *testing.T) { + // Create the session stores + toolStore := newSessionToolsStore() + resourceStore := newSessionResourcesStore() + templatesStore := newSessionResourceTemplatesStore() + logStore := newSessionLogLevelsStore() + + // Create a streamable HTTP session + session := newStreamableHttpSession("test-session", toolStore, resourceStore, templatesStore, logStore) + + // Verify it implements SessionWithClientInfo + var clientSession ClientSession = session + clientInfoSession, ok := clientSession.(SessionWithClientInfo) + if !ok { + t.Fatal("streamableHttpSession should implement SessionWithClientInfo") + } + + // Test GetClientInfo with no data set (should return empty) + clientInfo := clientInfoSession.GetClientInfo() + if clientInfo.Name != "" || clientInfo.Version != "" { + t.Errorf("expected empty client info, got %+v", clientInfo) + } + + // Test SetClientInfo and GetClientInfo + expectedClientInfo := mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + clientInfoSession.SetClientInfo(expectedClientInfo) + + actualClientInfo := clientInfoSession.GetClientInfo() + if actualClientInfo.Name != expectedClientInfo.Name || actualClientInfo.Version != expectedClientInfo.Version { + t.Errorf("expected client info %+v, got %+v", expectedClientInfo, actualClientInfo) + } + + // Test GetClientCapabilities with no data set (should return empty) + capabilities := clientInfoSession.GetClientCapabilities() + if capabilities.Sampling != nil || capabilities.Roots != nil { + t.Errorf("expected empty client capabilities, got %+v", capabilities) + } + + // Test SetClientCapabilities and GetClientCapabilities + expectedCapabilities := mcp.ClientCapabilities{ + Sampling: &struct{}{}, + } + clientInfoSession.SetClientCapabilities(expectedCapabilities) + + actualCapabilities := clientInfoSession.GetClientCapabilities() + if actualCapabilities.Sampling == nil { + t.Errorf("expected sampling capability to be set") + } +} diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index f83a95eb..9e444c53 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -125,7 +125,7 @@ func TestStreamableHTTP_POST_InvalidContent(t *testing.T) { func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { mcpServer := NewMCPServer("test-mcp-server", "1.0") addSSETool(mcpServer) - server := NewTestStreamableHTTPServer(mcpServer) + server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) var sessionID string t.Run("initialize", func(t *testing.T) { @@ -595,6 +595,7 @@ func TestStreamableHttpResourceGet(t *testing.T) { testServer := NewTestStreamableHTTPServer( s, + WithStateful(true), WithHTTPContextFunc(func(ctx context.Context, r *http.Request) context.Context { session := ClientSessionFromContext(ctx) @@ -1014,7 +1015,7 @@ func TestStreamableHTTP_SessionWithLogging(t *testing.T) { }) mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks), WithLogging()) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) defer testServer.Close() // obtain a valid session ID first @@ -1404,7 +1405,7 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) { server := NewTestStreamableHTTPServer(mcpServer) defer server.Close() - t.Run("Reject tool call with fake session ID", func(t *testing.T) { + t.Run("Accept tool call with properly formatted session ID", func(t *testing.T) { toolCallRequest := map[string]any{ "jsonrpc": "2.0", "id": 1, @@ -1425,13 +1426,29 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) { } defer resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) } body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), "Invalid session ID") { - t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body)) + var response map[string]any + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if result, ok := response["result"].(map[string]any); ok { + if content, ok := result["content"].([]any); ok && len(content) > 0 { + if textContent, ok := content[0].(map[string]any); ok { + if text, ok := textContent["text"].(string); ok { + // Should be a valid timestamp response + if text == "" { + t.Error("Expected non-empty timestamp response") + } + } + } + } + } else { + t.Errorf("Expected result in response, got: %s", string(body)) } }) @@ -1508,22 +1525,45 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) { } }) - t.Run("Reject tool call with terminated session ID", func(t *testing.T) { + t.Run("Reject tool call with terminated session ID (stateful mode)", func(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + // Use explicit stateful mode for this test since termination requires local tracking + server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) + defer server.Close() + + // First, initialize a session + initRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]any{ + "tools": map[string]any{}, + }, + "clientInfo": map[string]any{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + jsonBody, _ := json.Marshal(initRequest) req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) req.Header.Set("Content-Type", "application/json") resp, err := server.Client().Do(req) if err != nil { - t.Fatalf("Failed to initialize: %v", err) + t.Fatalf("Failed to initialize session: %v", err) } - resp.Body.Close() sessionID := resp.Header.Get(HeaderKeySessionID) if sessionID == "" { t.Fatal("Expected session ID in response header") } + resp.Body.Close() + // Now terminate the session req, _ = http.NewRequest(http.MethodDelete, server.URL, nil) req.Header.Set(HeaderKeySessionID, sessionID) @@ -1780,13 +1820,19 @@ func TestDefaultSessionIdManagerResolver(t *testing.T) { t.Error("Expected resolver to return a non-nil manager") } - // Test that the resolved manager works (generates valid session IDs) + // Test that the resolved manager works (stateless behavior) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected default manager to generate non-empty session ID") + if sessionID != "" { + t.Error("Expected stateless manager to generate empty session ID") } - if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected default manager to generate session ID with correct prefix") + + // Test that validation accepts any session ID (stateless behavior) + isTerminated, err := resolved.Validate("any-session-id") + if err != nil { + t.Errorf("Expected stateless manager to accept any session ID, got error: %v", err) + } + if isTerminated { + t.Error("Expected stateless manager to not terminate sessions") } }) } @@ -1865,17 +1911,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { server := NewStreamableHTTPServer(mcpServer, WithStateLess(false)) - // Test that the default manager is still used (InsecureStatefulSessionIdManager) + // Test that the default manager is still used (StatelessGeneratingSessionIdManager) req, _ := http.NewRequest("POST", "/test", nil) resolved := server.sessionIdManagerResolver.ResolveSessionIdManager(req) - // Verify it's NOT a stateless manager + // Verify it's a generating manager (default behavior) sessionID := resolved.Generate() if sessionID == "" { - t.Error("Expected stateful manager when WithStateLess(false)") + t.Error("Expected generating manager to generate session ID by default") } if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected stateful session ID format") + t.Error("Expected generating manager to generate session ID with correct prefix") } }) @@ -1929,7 +1975,7 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Run("WithSessionIdManagerResolver handles nil resolver defensively", func(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") - // This should not panic and should fall back to default behavior + // This should not panic and should fall back to StatelessSessionIdManager (safe default) server := NewStreamableHTTPServer(mcpServer, WithSessionIdManagerResolver(nil)) req, _ := http.NewRequest("POST", "/test", nil) @@ -1938,20 +1984,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Error("Expected nil resolver to be replaced with default") } - // Test that the resolved manager works (should be default stateful manager) + // Test that the resolved manager works (should be default stateless manager) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected default manager to generate non-empty session ID") - } - if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected default manager to generate session ID with correct prefix") + if sessionID != "" { + t.Error("Expected default stateless manager to generate empty session ID") } }) t.Run("WithSessionIdManager handles nil manager defensively", func(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") - // This should not panic and should fall back to default behavior + // This should not panic and should fall back to StatelessSessionIdManager (safe default) server := NewStreamableHTTPServer(mcpServer, WithSessionIdManager(nil)) req, _ := http.NewRequest("POST", "/test", nil) @@ -1960,20 +2003,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Error("Expected nil manager to be replaced with default") } - // Test that the resolved manager works (should be default stateful manager) + // Test that the resolved manager works (should be default stateless manager) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected default manager to generate non-empty session ID") - } - if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected default manager to generate session ID with correct prefix") + if sessionID != "" { + t.Error("Expected default stateless manager to generate empty session ID") } }) t.Run("Multiple nil options fall back safely", func(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") - // Chain multiple nil options - last one should win with safe fallback + // Chain multiple nil options - last one should win with StatelessSessionIdManager fallback server := NewStreamableHTTPServer(mcpServer, WithSessionIdManager(nil), WithSessionIdManagerResolver(nil), @@ -1985,10 +2025,10 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Error("Expected chained nil options to fall back safely") } - // Verify it generates valid session IDs + // Verify it uses stateless behavior (default) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected fallback manager to generate non-empty session ID") + if sessionID != "" { + t.Error("Expected fallback stateless manager to generate empty session ID") } }) @@ -2021,6 +2061,28 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { } _ = resp.Body.Close() }) + + t.Run("WithStateful enables stateful manager", func(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + server := NewStreamableHTTPServer(mcpServer, WithStateful(true)) + + req, _ := http.NewRequest("POST", "/test", nil) + resolved := server.sessionIdManagerResolver.ResolveSessionIdManager(req) + + sessionID := resolved.Generate() + if sessionID == "" { + t.Error("Expected stateful manager to generate session ID") + } + if !strings.HasPrefix(sessionID, idPrefix) { + t.Error("Expected stateful session ID format") + } + + // Test that stateful manager validates session existence (unlike default) + _, err := resolved.Validate("unknown-session-id") + if err == nil { + t.Error("Expected stateful manager to reject unknown session ID") + } + }) } func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { @@ -2039,7 +2101,7 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { }) mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) defer testServer.Close() // Send initialize request to register session @@ -2110,7 +2172,7 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { return mcp.NewToolResultText("notification sent"), nil }) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) defer testServer.Close() // Initialize session