diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fe659cf..9b4eb7fd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,58 +4,32 @@ on: branches: - main pull_request: - pull_request_target: workflow_dispatch: jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: 'go.mod' - - run: go test ./... -race - - coverage: - runs-on: ubuntu-latest - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' || (github.event_name == 'push' && github.ref == 'refs/heads/main') - permissions: - contents: read - pull-requests: write - steps: - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.pull_request.head.sha }} - - uses: actions/setup-go@v5 - with: - go-version-file: 'go.mod' - - name: Run tests with coverage - run: | - go test -coverprofile=coverage.txt -covermode=atomic $(go list ./... | grep -v '/examples/' | grep -v '/testdata' | grep -v '/mcptest' | grep -v '/server/internal/gen') - - name: Upload coverage artifact - uses: actions/upload-artifact@v4 - with: - name: code-coverage - path: coverage.txt - retention-days: 30 - - name: Generate coverage report - uses: fgrosse/go-coverage-report@v1.2.0 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: "go.mod" + - run: go test ./... -race verify-codegen: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: 'go.mod' - - name: Run code generation - run: go generate ./... - - name: Check for uncommitted changes - run: | - if [[ -n $(git status --porcelain) ]]; then - echo "Error: Generated code is not up to date. Please run 'go generate ./...' and commit the changes." - git status - git diff - exit 1 - fi + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: "go.mod" + - name: Run code generation + run: go generate ./... + - name: Check for uncommitted changes + run: | + if [[ -n $(git status --porcelain) ]]; then + echo "Error: Generated code is not up to date. Please run 'go generate ./...' and commit the changes." + git status + git diff + exit 1 + fi diff --git a/mcp/tools.go b/mcp/tools.go index 80ae5091..42e888d5 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -653,6 +653,31 @@ func (tis ToolArgumentsSchema) MarshalJSON() ([]byte, error) { return json.Marshal(m) } +// UnmarshalJSON implements the json.Unmarshaler interface for ToolArgumentsSchema. +// It handles both "$defs" (JSON Schema 2019-09+) and "definitions" (JSON Schema draft-07) +// by reading either field and storing it in the Defs field. +func (tis *ToolArgumentsSchema) UnmarshalJSON(data []byte) error { + // Use a temporary type to avoid infinite recursion + type Alias ToolArgumentsSchema + aux := &struct { + Definitions map[string]any `json:"definitions,omitempty"` + *Alias + }{ + Alias: (*Alias)(tis), + } + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + // If $defs wasn't provided but definitions was, use definitions + if tis.Defs == nil && aux.Definitions != nil { + tis.Defs = aux.Definitions + } + + return nil +} + type ToolAnnotation struct { // Human-readable title for the tool Title string `json:"title,omitempty"` diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 16aeb5df..ef472e47 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -1549,3 +1549,136 @@ func TestToolMetaMarshalingOmitsWhenNil(t *testing.T) { // Check that _meta field is not present assert.NotContains(t, result, "_meta", "Tool without Meta should not include _meta field") } + +func TestToolArgumentsSchema_UnmarshalWithDefinitions(t *testing.T) { + // Test that "definitions" (JSON Schema draft-07) is properly unmarshaled into Defs field + jsonData := `{ + "type": "object", + "properties": { + "operation": { + "$ref": "#/definitions/operation_type" + } + }, + "required": ["operation"], + "definitions": { + "operation_type": { + "type": "string", + "enum": ["create", "read", "update", "delete"] + } + } + }` + + var schema ToolArgumentsSchema + err := json.Unmarshal([]byte(jsonData), &schema) + assert.NoError(t, err) + + // Verify the schema was properly unmarshaled + assert.Equal(t, "object", schema.Type) + assert.Contains(t, schema.Properties, "operation") + assert.Equal(t, []string{"operation"}, schema.Required) + + // Most importantly: verify that "definitions" was read into Defs field + assert.NotNil(t, schema.Defs) + assert.Contains(t, schema.Defs, "operation_type") + + operationType, ok := schema.Defs["operation_type"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "string", operationType["type"]) + assert.NotNil(t, operationType["enum"]) +} + +func TestToolArgumentsSchema_UnmarshalWithDefs(t *testing.T) { + // Test that "$defs" (JSON Schema 2019-09+) is properly unmarshaled into Defs field + jsonData := `{ + "type": "object", + "properties": { + "operation": { + "$ref": "#/$defs/operation_type" + } + }, + "required": ["operation"], + "$defs": { + "operation_type": { + "type": "string", + "enum": ["create", "read", "update", "delete"] + } + } + }` + + var schema ToolArgumentsSchema + err := json.Unmarshal([]byte(jsonData), &schema) + assert.NoError(t, err) + + // Verify the schema was properly unmarshaled + assert.Equal(t, "object", schema.Type) + assert.Contains(t, schema.Properties, "operation") + assert.Equal(t, []string{"operation"}, schema.Required) + + // Verify that "$defs" was read into Defs field + assert.NotNil(t, schema.Defs) + assert.Contains(t, schema.Defs, "operation_type") + + operationType, ok := schema.Defs["operation_type"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "string", operationType["type"]) + assert.NotNil(t, operationType["enum"]) +} + +func TestToolArgumentsSchema_UnmarshalPrefersDefs(t *testing.T) { + // Test that if both "$defs" and "definitions" are present, "$defs" takes precedence + jsonData := `{ + "type": "object", + "$defs": { + "from_defs": { + "type": "string" + } + }, + "definitions": { + "from_definitions": { + "type": "integer" + } + } + }` + + var schema ToolArgumentsSchema + err := json.Unmarshal([]byte(jsonData), &schema) + assert.NoError(t, err) + + // $defs should take precedence + assert.Contains(t, schema.Defs, "from_defs") + assert.NotContains(t, schema.Defs, "from_definitions") +} + +func TestToolArgumentsSchema_MarshalRoundTrip(t *testing.T) { + // Test that marshaling and unmarshaling preserves definitions + original := ToolArgumentsSchema{ + Type: "object", + Properties: map[string]any{ + "field": map[string]any{ + "$ref": "#/$defs/my_type", + }, + }, + Required: []string{"field"}, + Defs: map[string]any{ + "my_type": map[string]any{ + "type": "string", + "enum": []string{"a", "b", "c"}, + }, + }, + } + + // Marshal + data, err := json.Marshal(original) + assert.NoError(t, err) + + // Unmarshal + var unmarshaled ToolArgumentsSchema + err = json.Unmarshal(data, &unmarshaled) + assert.NoError(t, err) + + // Verify round-trip + assert.Equal(t, original.Type, unmarshaled.Type) + assert.Equal(t, original.Required, unmarshaled.Required) + assert.NotNil(t, unmarshaled.Defs) + assert.Contains(t, unmarshaled.Defs, "my_type") +} diff --git a/server/session.go b/server/session.go index 0ded99fb..48fd52d7 100644 --- a/server/session.go +++ b/server/session.go @@ -171,6 +171,9 @@ func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) { s.sessions.Range(func(k, v any) bool { if session, ok := v.(ClientSession); ok && session.Initialized() { + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } select { case session.NotificationChannel() <- notification: // Successfully sent notification diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 9e444c53..32dccc3a 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -2251,3 +2251,91 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { } }) } + +// TestStreamableHTTP_AddToolDuringToolCall tests that adding a tool while a tool call +// is in progress doesn't break the client's response. +// This is a regression test for issue #638 where notifications sent via +// sendNotificationToAllClients during an in-progress request would cause +// the response to fail with "unexpected nil response". +func TestStreamableHTTP_AddToolDuringToolCall(t *testing.T) { + mcpServer := NewMCPServer("test-mcp-server", "1.0", + WithToolCapabilities(true), // Enable tool list change notifications + ) + // Add a tool that takes some time to complete + mcpServer.AddTool(mcp.NewTool("slow_tool", + mcp.WithDescription("A tool that takes time to complete"), + ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Simulate work that takes some time + time.Sleep(100 * time.Millisecond) + return mcp.NewToolResultText("done"), nil + }) + server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) + defer server.Close() + // Initialize to get session + resp, err := postJSON(server.URL, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + sessionID := resp.Header.Get(HeaderKeySessionID) + resp.Body.Close() + if sessionID == "" { + t.Fatal("Expected session ID in response header") + } + // Start the tool call in a goroutine + resultChan := make(chan struct { + statusCode int + body string + err error + }) + go func() { + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "slow_tool", + }, + } + toolBody, _ := json.Marshal(toolRequest) + req, _ := http.NewRequest("POST", server.URL, bytes.NewReader(toolBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, sessionID) + resp, err := server.Client().Do(req) + if err != nil { + resultChan <- struct { + statusCode int + body string + err error + }{0, "", err} + return + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + resultChan <- struct { + statusCode int + body string + err error + }{resp.StatusCode, string(body), nil} + }() + // Wait a bit then add a new tool while the slow_tool is executing + // This triggers sendNotificationToAllClients + time.Sleep(50 * time.Millisecond) + mcpServer.AddTool(mcp.NewTool("new_tool", + mcp.WithDescription("A new tool added during execution"), + ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("new tool result"), nil + }) + // Wait for the tool call to complete + result := <-resultChan + if result.err != nil { + t.Fatalf("Tool call failed with error: %v", result.err) + } + if result.statusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d. Body: %s", result.statusCode, result.body) + } + // The response should contain the tool result + // It may be SSE format (text/event-stream) due to the notification upgrade + if !strings.Contains(result.body, "done") { + t.Errorf("Expected response to contain 'done', got: %s", result.body) + } +}