diff --git a/pkg/tools/registry/get_latest_module_version.go b/pkg/tools/registry/get_latest_module_version.go index f4d43448..80ce28db 100644 --- a/pkg/tools/registry/get_latest_module_version.go +++ b/pkg/tools/registry/get_latest_module_version.go @@ -45,37 +45,37 @@ func GetLatestModuleVersion(logger *log.Logger) server.ServerTool { func getLatestModuleVersionHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { modulePublisher, err := request.RequireString("module_publisher") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: 'module_publisher' (the publisher of the module)", err) + return utils.ToolError(logger, "required input: 'module_publisher' (the publisher of the module)", err) } modulePublisher = strings.ToLower(modulePublisher) moduleName, err := request.RequireString("module_name") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: 'module_name' (the name of the module)", err) + return utils.ToolError(logger, "required input: 'module_name' (the name of the module)", err) } moduleName = strings.ToLower(moduleName) moduleProvider, err := request.RequireString("module_provider") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: 'module_provider' (the provider of the module)", err) + return utils.ToolError(logger, "required input: 'module_provider' (the provider of the module)", err) } moduleProvider = strings.ToLower(moduleProvider) // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } + uri := fmt.Sprintf("modules/%s/%s/%s", modulePublisher, moduleName, moduleProvider) response, err := client.SendRegistryCall(httpClient, http.MethodGet, uri, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("fetching module information for %s/%s from the %s provider", modulePublisher, moduleName, moduleProvider), err) + return utils.ToolErrorf(logger, "fetching module information for %s/%s from the %s provider: %v", modulePublisher, moduleName, moduleProvider, err) } var moduleVersionDetails client.TerraformModuleVersionDetails if err := json.Unmarshal(response, &moduleVersionDetails); err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("unmarshalling module information for %s/%s from the %s provider", modulePublisher, moduleName, moduleProvider), err) + return utils.ToolErrorf(logger, "unmarshalling module information for %s/%s from the %s provider: %v", modulePublisher, moduleName, moduleProvider, err) } return mcp.NewToolResultText(moduleVersionDetails.Version), nil diff --git a/pkg/tools/registry/get_latest_provider_version.go b/pkg/tools/registry/get_latest_provider_version.go index 30e34a41..bb67ea69 100644 --- a/pkg/tools/registry/get_latest_provider_version.go +++ b/pkg/tools/registry/get_latest_provider_version.go @@ -5,7 +5,6 @@ package tools import ( "context" - "fmt" "strings" "github.com/hashicorp/terraform-mcp-server/pkg/client" @@ -40,26 +39,24 @@ func GetLatestProviderVersion(logger *log.Logger) server.ServerTool { func getLatestProviderVersionHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { namespace, err := request.RequireString("namespace") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: namespace of the Terraform provider is required", err) + return utils.ToolError(logger, "missing required input: namespace", err) } namespace = strings.ToLower(namespace) name, err := request.RequireString("name") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: name of the Terraform provider is required", err) + return utils.ToolError(logger, "missing required input: name", err) } name = strings.ToLower(name) - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } version, err := client.GetLatestProviderVersion(httpClient, namespace, name, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, "fetching latest provider version", err) + return utils.ToolErrorf(logger, "provider not found: %s/%s - verify the namespace and provider name are correct", namespace, name) } return mcp.NewToolResultText(version), nil diff --git a/pkg/tools/registry/get_module_details.go b/pkg/tools/registry/get_module_details.go index 4a5130d6..a2748e3b 100644 --- a/pkg/tools/registry/get_module_details.go +++ b/pkg/tools/registry/get_module_details.go @@ -42,40 +42,37 @@ func ModuleDetails(logger *log.Logger) server.ServerTool { func getModuleDetailsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { moduleID, err := request.RequireString("module_id") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: module_id is required", err) + return utils.ToolError(logger, "missing required input: module_id", err) } if moduleID == "" { - return nil, utils.LogAndReturnError(logger, "required input: module_id cannot be empty", nil) + return utils.ToolError(logger, "module_id cannot be empty", nil) } - + // Validate module ID format if err := validateModuleID(moduleID); err != nil { - return nil, utils.LogAndReturnError(logger, err.Error(), nil) + return utils.ToolError(logger, err.Error(), nil) } - + moduleID = strings.ToLower(moduleID) - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } - var errMsg string response, err := getModuleDetails(httpClient, moduleID, 0, logger) if err != nil { - errMsg = fmt.Sprintf("getting module(s), none found! module_id: %v,", moduleID) - return nil, utils.LogAndReturnError(logger, errMsg, nil) + return utils.ToolErrorf(logger, "module not found: %s - use search_modules first to find valid module IDs", moduleID) } + moduleData, err := unmarshalTerraformModule(response) if err != nil { - return nil, utils.LogAndReturnError(logger, "unmarshalling module details", err) + return utils.ToolError(logger, "failed to parse module details", err) } if moduleData == "" { - errMsg = fmt.Sprintf("getting module(s), none found! %s please provider a different moduleProvider", errMsg) - return nil, utils.LogAndReturnError(logger, errMsg, nil) + return utils.ToolErrorf(logger, "no module data returned for %s - try a different module_id", moduleID) } + return mcp.NewToolResultText(moduleData), nil } @@ -88,20 +85,17 @@ func getModuleDetails(httpClient *http.Client, moduleID string, currentOffset in uri = fmt.Sprintf("%s?offset=%v", uri, currentOffset) response, err := client.SendRegistryCall(httpClient, "GET", uri, logger) if err != nil { - // We shouldn't log the error here because we might hit a namespace that doesn't exist, it's better to let the caller handle it. return nil, fmt.Errorf("getting module(s) for: %v, please provide a different provider name like aws, azurerm or google etc", moduleID) } - // Return the filtered JSON as a string return response, nil } func unmarshalTerraformModule(response []byte) (string, error) { - // Handles one module var terraformModules client.TerraformModuleVersionDetails err := json.Unmarshal(response, &terraformModules) if err != nil { - return "", utils.LogAndReturnError(nil, "unmarshalling module details", err) + return "", fmt.Errorf("unmarshalling module details: %w", err) } var builder strings.Builder @@ -120,7 +114,7 @@ func unmarshalTerraformModule(response []byte) (string, error) { builder.WriteString(fmt.Sprintf("| %s | %s | %s | `%v` | %t |\n", input.Name, input.Type, - input.Description, // Consider cleaning potential newlines/markdown + input.Description, input.Default, input.Required, )) @@ -136,7 +130,7 @@ func unmarshalTerraformModule(response []byte) (string, error) { for _, output := range terraformModules.Root.Outputs { builder.WriteString(fmt.Sprintf("| %s | %s |\n", output.Name, - output.Description, // Consider cleaning potential newlines/markdown + output.Description, )) } builder.WriteString("\n") @@ -163,11 +157,8 @@ func unmarshalTerraformModule(response []byte) (string, error) { builder.WriteString("### Examples\n\n") for _, example := range terraformModules.Examples { builder.WriteString(fmt.Sprintf("#### %s\n\n", example.Name)) - // Optionally, include more details from example if needed, like inputs/outputs - // For now, just listing the name. if example.Readme != "" { builder.WriteString("**Readme:**\n\n") - // Append readme content, potentially needs markdown escaping/sanitization depending on source builder.WriteString(example.Readme) builder.WriteString("\n\n") } diff --git a/pkg/tools/registry/get_policy_details.go b/pkg/tools/registry/get_policy_details.go index 4272e95e..f30caca3 100644 --- a/pkg/tools/registry/get_policy_details.go +++ b/pkg/tools/registry/get_policy_details.go @@ -41,26 +41,25 @@ func PolicyDetails(logger *log.Logger) server.ServerTool { func getPolicyDetailsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { terraformPolicyID, err := request.RequireString("terraform_policy_id") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: terraform_policy_id is required and must be a string, it is fetched by running the search_policies tool", err) + return utils.ToolError(logger, "missing required input: terraform_policy_id - use search_policies first to find valid policy IDs", err) } if terraformPolicyID == "" { - return nil, utils.LogAndReturnError(logger, "required input: terraform_policy_id cannot be empty, it is fetched by running the search_policies tool", nil) + return utils.ToolError(logger, "terraform_policy_id cannot be empty - use search_policies first to find valid policy IDs", nil) } - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } + policyResp, err := client.SendRegistryCall(httpClient, "GET", (&url.URL{Path: terraformPolicyID, RawQuery: url.Values{"include": {"policies,policy-modules,policy-library"}}.Encode()}).String(), logger, "v2") if err != nil { - return nil, utils.LogAndReturnError(logger, "fetching policy details: registry API did not return a successful response", err) + return utils.ToolErrorf(logger, "policy not found: %s - verify the terraform_policy_id is correct or use search_policies to find valid IDs", terraformPolicyID) } var policyDetails client.TerraformPolicyDetails if err := json.Unmarshal(policyResp, &policyDetails); err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("unmarshalling policy details for %s", terraformPolicyID), err) + return utils.ToolErrorf(logger, "failed to parse policy details for %s", terraformPolicyID) } readme := utils.ExtractReadme(policyDetails.Data.Attributes.Readme) @@ -70,7 +69,6 @@ func getPolicyDetailsHandler(ctx context.Context, request mcp.CallToolRequest, l moduleList := "" for _, policy := range policyDetails.Included { if policy.Type == "policy-modules" { - // Use text/template to safely build the module block var moduleBuilder strings.Builder tmpl := ` module "{{.Name}}" { @@ -103,7 +101,6 @@ module "{{.Name}}" { builder.WriteString("## Usage\n\n") builder.WriteString("Generate the content for a HashiCorp Configuration Language (HCL) file named policies.hcl. This file should define a set of policies. For each policy provided, create a distinct policy block using the following template.\n") builder.WriteString("\n```hcl\n") - // Use text/template to safely build the HCL template for policies hclTmpl := ` {{- if .ModuleList }} {{ .ModuleList }} diff --git a/pkg/tools/registry/get_provider_capabilities.go b/pkg/tools/registry/get_provider_capabilities.go index 81c4485f..44098f52 100644 --- a/pkg/tools/registry/get_provider_capabilities.go +++ b/pkg/tools/registry/get_provider_capabilities.go @@ -57,52 +57,46 @@ Returns a summary with counts and examples for each capability type.`), func getProviderCapabilitiesHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { namespace, err := request.RequireString("namespace") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: namespace of the Terraform provider is required", err) + return utils.ToolError(logger, "missing required input: namespace", err) } namespace = strings.ToLower(namespace) name, err := request.RequireString("name") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: name of the Terraform provider is required", err) + return utils.ToolError(logger, "missing required input: name", err) } name = strings.ToLower(name) version := request.GetString("version", "latest") if version == "latest" || !utils.IsValidProviderVersionFormat(version) { - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } latestVersion, err := client.GetLatestProviderVersion(httpClient, namespace, name, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, "fetching latest provider version", err) + return utils.ToolErrorf(logger, "provider not found: %s/%s - verify the namespace and provider name are correct", namespace, name) } version = latestVersion } - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } - // Get provider documentation uri := fmt.Sprintf("providers/%s/%s/%s", namespace, name, version) response, err := client.SendRegistryCall(httpClient, "GET", uri, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("fetching provider docs for %s/%s:%s", namespace, name, version), err) + return utils.ToolErrorf(logger, "failed to fetch provider docs for %s/%s:%s - verify the provider exists", namespace, name, version) } var providerDocs client.ProviderDocs if err := json.Unmarshal(response, &providerDocs); err != nil { - return nil, utils.LogAndReturnError(logger, "unmarshalling provider docs", err) + return utils.ToolErrorf(logger, "failed to parse provider docs for %s/%s:%s", namespace, name, version) } - // Analyze and format capabilities output := analyzeAndFormatCapabilities(providerDocs, namespace, name, version) return mcp.NewToolResultText(output), nil } @@ -110,7 +104,6 @@ func getProviderCapabilitiesHandler(ctx context.Context, request mcp.CallToolReq func analyzeAndFormatCapabilities(docs client.ProviderDocs, namespace, name, version string) string { capabilities := make(map[string][]client.ProviderDoc) - // Analyze documentation for _, doc := range docs.Docs { if doc.Language != "hcl" { continue @@ -128,13 +121,11 @@ func analyzeAndFormatCapabilities(docs client.ProviderDocs, namespace, name, ver return builder.String() } - // Show all capabilities as discovered for capType, items := range capabilities { title := strings.ReplaceAll(capType, "-", " ") title = cases.Title(language.English).String(title) builder.WriteString(fmt.Sprintf("%s: %d available\n", title, len(items))) - // Dynamic listing: show all if ≤10, otherwise show 3 with "more" message limit := 3 if len(items) <= 10 { limit = len(items) diff --git a/pkg/tools/registry/get_provider_details.go b/pkg/tools/registry/get_provider_details.go index 0f19aca5..55e978bf 100644 --- a/pkg/tools/registry/get_provider_details.go +++ b/pkg/tools/registry/get_provider_details.go @@ -6,7 +6,6 @@ package tools import ( "context" "encoding/json" - "fmt" "path" "strconv" @@ -40,30 +39,29 @@ You must call 'search_providers' tool first to obtain the exact tfprovider-compa func getProviderDocsHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { providerDocID, err := request.RequireString("provider_doc_id") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: provider_doc_id is required", err) + return utils.ToolError(logger, "missing required input: provider_doc_id", err) } if providerDocID == "" { - return nil, utils.LogAndReturnError(logger, "required input: provider_doc_id cannot be empty", nil) + return utils.ToolError(logger, "provider_doc_id cannot be empty", nil) } if _, err := strconv.Atoi(providerDocID); err != nil { - return nil, utils.LogAndReturnError(logger, "required input: provider_doc_id must be a valid number", err) + return utils.ToolError(logger, "provider_doc_id must be a valid number - use search_providers first to find valid IDs", err) } - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } detailResp, err := client.SendRegistryCall(httpClient, "GET", path.Join("provider-docs", providerDocID), logger, "v2") if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("fetching provider-docs/%s, please make sure provider_doc_id is valid and the search_providers tool has run prior", providerDocID), err) + return utils.ToolErrorf(logger, "provider doc not found: %s - use search_providers first to find valid provider_doc_id values", providerDocID) } var details client.ProviderResourceDetails if err := json.Unmarshal(detailResp, &details); err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("unmarshalling provider-docs/%s", providerDocID), err) + return utils.ToolErrorf(logger, "failed to parse provider docs for %s", providerDocID) } + return mcp.NewToolResultText(details.Data.Attributes.Content), nil } diff --git a/pkg/tools/registry/search_modules.go b/pkg/tools/registry/search_modules.go index 3e9f79a1..961c5e15 100644 --- a/pkg/tools/registry/search_modules.go +++ b/pkg/tools/registry/search_modules.go @@ -55,33 +55,30 @@ If no modules were found, reattempt the search with a new moduleName query.`), func getSearchModulesHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { moduleQuery, err := request.RequireString("module_query") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: module_query is required", err) + return utils.ToolError(logger, "missing required input: module_query", err) } moduleQuery = strings.ToLower(moduleQuery) currentOffsetValue := request.GetInt("current_offset", 0) - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } - var modulesData, errMsg string response, err := sendSearchModulesCall(httpClient, moduleQuery, currentOffsetValue, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("finding module(s): none found for moduleName: %s", moduleQuery), err) - } else { - modulesData, err = unmarshalTerraformModules(response, moduleQuery, logger) - if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf("unmarshalling modules for moduleName: %s", moduleQuery), err) - } + return utils.ToolErrorf(logger, "no modules found for query: %s - try a different search term", moduleQuery) + } + + modulesData, err := unmarshalTerraformModules(response, moduleQuery, logger) + if err != nil { + return utils.ToolErrorf(logger, "failed to parse module results for query: %s", moduleQuery) } if modulesData == "" { - errMsg = fmt.Sprintf("getting module(s), none found! query used: %s; error: %s", moduleQuery, errMsg) - return nil, utils.LogAndReturnError(logger, errMsg, nil) + return utils.ToolErrorf(logger, "no modules found for query: %s - try a different search term", moduleQuery) } + return mcp.NewToolResultText(modulesData), nil } @@ -95,27 +92,23 @@ func sendSearchModulesCall(providerClient *http.Client, moduleQuery string, curr response, err := client.SendRegistryCall(providerClient, "GET", uri, logger) if err != nil { - // We shouldn't log the error here because we might hit a namespace that doesn't exist, it's better to let the caller handle it. return nil, fmt.Errorf("getting module(s) for: %v, call error: %v", moduleQuery, err) } - // Return the filtered JSON as a string return response, nil } func unmarshalTerraformModules(response []byte, moduleQuery string, logger *log.Logger) (string, error) { - // Get the list of modules var terraformModules client.TerraformModules err := json.Unmarshal(response, &terraformModules) if err != nil { - return "", utils.LogAndReturnError(logger, "unmarshalling modules", err) + return "", fmt.Errorf("unmarshalling modules: %w", err) } if len(terraformModules.Data) == 0 { - return "", utils.LogAndReturnError(logger, fmt.Sprintf("no modules found for query: %s", moduleQuery), nil) + return "", fmt.Errorf("no modules found for query: %s", moduleQuery) } - // Sort by most downloaded sort.Slice(terraformModules.Data, func(i, j int) bool { return terraformModules.Data[i].Downloads > terraformModules.Data[j].Downloads }) diff --git a/pkg/tools/registry/search_policies.go b/pkg/tools/registry/search_policies.go index eedcc5b7..3b750087 100644 --- a/pkg/tools/registry/search_policies.go +++ b/pkg/tools/registry/search_policies.go @@ -47,37 +47,36 @@ If no policies were found, reattempt the search with a new policy_query.`), } func getSearchPoliciesHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { - var terraformPolicies client.TerraformPolicyList pq, err := request.RequireString("policy_query") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: policy_query is required", err) + return utils.ToolError(logger, "missing required input: policy_query", err) } if pq == "" { - return nil, utils.LogAndReturnError(logger, "required input: policy_query cannot be empty", nil) + return utils.ToolError(logger, "policy_query cannot be empty", nil) } pq = strings.ToLower(pq) - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } + uri := (&url.URL{ Path: "policies", RawQuery: url.Values{ - "page[size]": {"100"}, // static list of 100 is fine for now + "page[size]": {"100"}, "include": {"latest-version"}, }.Encode(), }).String() + policyResp, err := client.SendRegistryCall(httpClient, "GET", uri, logger, "v2") if err != nil { - return nil, utils.LogAndReturnError(logger, "fetching policies: registry API did not return a successful response", err) + return utils.ToolError(logger, "failed to fetch policies from registry", err) } - err = json.Unmarshal(policyResp, &terraformPolicies) - if err != nil { - return nil, utils.LogAndReturnError(logger, "unmarshalling policy list", err) + var terraformPolicies client.TerraformPolicyList + if err := json.Unmarshal(policyResp, &terraformPolicies); err != nil { + return utils.ToolError(logger, "failed to parse policy list", err) } var builder strings.Builder @@ -101,11 +100,9 @@ func getSearchPoliciesHandler(ctx context.Context, request mcp.CallToolRequest, } } - policyData := builder.String() if !contentAvailable { - errMessage := fmt.Sprintf("finding policies, none found matching the query: %s. Try a different policy_query.", pq) - return nil, utils.LogAndReturnError(logger, errMessage, nil) + return utils.ToolErrorf(logger, "no policies found matching query: %s - try a different search term", pq) } - return mcp.NewToolResultText(policyData), nil + return mcp.NewToolResultText(builder.String()), nil } diff --git a/pkg/tools/registry/search_providers.go b/pkg/tools/registry/search_providers.go index d4f20cf8..db8680d2 100644 --- a/pkg/tools/registry/search_providers.go +++ b/pkg/tools/registry/search_providers.go @@ -68,26 +68,24 @@ for listing resources using Terraform Search use 'list-resources'`), } func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) { - // For typical provider and namespace hallucinations defaultErrorGuide := "please check the provider name, provider namespace or the provider version you're looking for, perhaps the provider is published under a different namespace or company name" - // Get a simple http client to access the public Terraform registry from context httpClient, err := client.GetHttpClientFromContext(ctx, logger) if err != nil { - logger.WithError(err).Error("failed to get http client for public Terraform registry") - return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil + return utils.ToolError(logger, "failed to get http client for public Terraform registry", err) } - providerDetail, err := resolveProviderDetails(request, httpClient, defaultErrorGuide, logger) + + providerDetail, err := resolveProviderDetails(request, httpClient, logger) if err != nil { - return nil, err + return utils.ToolErrorf(logger, "failed to resolve provider: %v - %s", err, defaultErrorGuide) } serviceSlug, err := request.RequireString("service_slug") if err != nil { - return nil, utils.LogAndReturnError(logger, "required input: service_slug is required", err) + return utils.ToolError(logger, "missing required input: service_slug", err) } if serviceSlug == "" { - return nil, utils.LogAndReturnError(logger, "required input: service_slug cannot be empty", nil) + return utils.ToolError(logger, "service_slug cannot be empty", nil) } serviceSlug = strings.ToLower(serviceSlug) @@ -98,9 +96,8 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques if utils.IsV2ProviderDocumentType(providerDetail.ProviderDocumentType) { content, err := providerDetailsV2(httpClient, providerDetail, logger) if err != nil { - errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`, + return utils.ToolErrorf(logger, "failed to find %s documentation for provider '%s' in the '%s' namespace - %s", providerDetail.ProviderDocumentType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide) - return nil, utils.LogAndReturnError(logger, errMessage, err) } fullContent := fmt.Sprintf("# %s provider docs\n\n%s", @@ -113,12 +110,13 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) response, err := client.SendRegistryCall(httpClient, "GET", uri, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil) + return utils.ToolErrorf(logger, "failed to get provider '%s' version '%s' in namespace '%s' - %s", + providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide) } var providerDocs client.ProviderDocs if err := json.Unmarshal(response, &providerDocs); err != nil { - return nil, utils.LogAndReturnError(logger, "unmarshalling provider docs", err) + return utils.ToolError(logger, "failed to parse provider docs", err) } var builder strings.Builder @@ -142,25 +140,24 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques } } - // Check if the content data is not fulfilled if !contentAvailable { - errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug) - return nil, utils.LogAndReturnError(logger, errMessage, err) + return utils.ToolErrorf(logger, "no documentation found for service_slug '%s' - try a more relevant service_slug, or use the provider_name as the value", serviceSlug) } + return mcp.NewToolResultText(builder.String()), nil } -func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) { +func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, logger *log.Logger) (client.ProviderDetail, error) { providerDetail := client.ProviderDetail{} providerName := request.GetString("provider_name", "") if providerName == "" { - return providerDetail, fmt.Errorf("provider_name is required and must be a string") + return providerDetail, fmt.Errorf("provider_name is required") } providerName = strings.ToLower(providerName) providerNamespace := request.GetString("provider_namespace", "") if providerNamespace == "" { - logger.Debugf(`Error getting latest provider version in "%s" namespace, trying the hashicorp namespace`, providerNamespace) + logger.Debugf(`provider_namespace not provided, trying the hashicorp namespace`) providerNamespace = "hashicorp" } providerNamespace = strings.ToLower(providerNamespace) @@ -188,13 +185,13 @@ func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client tryProviderNamespace := "hashicorp" providerVersionValue, err = client.GetLatestProviderVersion(httpClient, tryProviderNamespace, providerName, logger) if err != nil { - // Just so we don't print the same namespace twice if they are the same + namespaceTried := providerNamespace if providerNamespace != tryProviderNamespace { - tryProviderNamespace = fmt.Sprintf(`"%s" or the "%s"`, providerNamespace, tryProviderNamespace) + namespaceTried = fmt.Sprintf("'%s' or '%s'", providerNamespace, tryProviderNamespace) } - return providerDetail, utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerName, providerVersion, tryProviderNamespace, defaultErrorGuide), nil) + return providerDetail, fmt.Errorf("provider '%s' version '%s' not found in namespace %s", providerName, providerVersion, namespaceTried) } - providerNamespace = tryProviderNamespace // Update the namespace to hashicorp, if successful + providerNamespace = tryProviderNamespace } providerDocumentTypeValue := "" @@ -209,12 +206,13 @@ func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client return providerDetail, nil } -// providerDetailsV2 retrieves a list of documentation items for a specific provider category using v2 API with support for pagination using page numbers +// providerDetailsV2 retrieves a list of documentation items for a specific provider category using v2 API func providerDetailsV2(httpClient *http.Client, providerDetail client.ProviderDetail, logger *log.Logger) (string, error) { providerVersionID, err := client.GetProviderVersionID(httpClient, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion, logger) if err != nil { - return "", utils.LogAndReturnError(logger, "getting provider version ID", err) + return "", fmt.Errorf("getting provider version ID: %w", err) } + category := providerDetail.ProviderDocumentType if category == "overview" { return client.GetProviderOverviewDocs(httpClient, providerVersionID, logger) @@ -225,7 +223,7 @@ func providerDetailsV2(httpClient *http.Client, providerDetail client.ProviderDe docs, err := client.SendPaginatedRegistryCall(httpClient, uriPrefix, logger) if err != nil { - return "", utils.LogAndReturnError(logger, "getting provider documentation", err) + return "", fmt.Errorf("getting provider documentation: %w", err) } if len(docs) == 0 { @@ -250,15 +248,15 @@ func providerDetailsV2(httpClient *http.Client, providerDetail client.ProviderDe func getContentSnippet(httpClient *http.Client, docID string, logger *log.Logger) (string, error) { docContent, err := client.SendRegistryCall(httpClient, "GET", fmt.Sprintf("provider-docs/%s", docID), logger, "v2") if err != nil { - return "", utils.LogAndReturnError(logger, fmt.Sprintf("fetching provider-docs/%s within getContentSnippet", docID), err) + return "", fmt.Errorf("fetching provider-docs/%s: %w", docID, err) } + var docDescription client.ProviderResourceDetails if err := json.Unmarshal(docContent, &docDescription); err != nil { - return "", utils.LogAndReturnError(logger, fmt.Sprintf("unmarshalling provider-docs/%s within getContentSnippet", docID), err) + return "", fmt.Errorf("unmarshalling provider-docs/%s: %w", docID, err) } content := docDescription.Data.Attributes.Content - // Try to extract description from markdown content desc := "" if start := strings.Index(content, "description: |-"); start != -1 { if end := strings.Index(content[start:], "\n---"); end != -1 { diff --git a/pkg/utils/errors.go b/pkg/utils/errors.go new file mode 100644 index 00000000..90853a45 --- /dev/null +++ b/pkg/utils/errors.go @@ -0,0 +1,34 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package utils + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + log "github.com/sirupsen/logrus" +) + +// ToolError returns a Tool Execution Error that the model can see and learn from. +// Unlike Protocol Errors, Tool Execution Errors are returned to the LLM's context window + +func ToolError(logger *log.Logger, message string, err error) (*mcp.CallToolResult, error) { + fullMessage := message + if err != nil { + fullMessage = fmt.Sprintf("%s: %v", message, err) + } + if logger != nil { + logger.Errorf("Tool error: %s", fullMessage) + } + return mcp.NewToolResultError(fullMessage), nil +} + +// ToolErrorf returns a formatted Tool Execution Error that the model can see. +func ToolErrorf(logger *log.Logger, format string, args ...interface{}) (*mcp.CallToolResult, error) { + message := fmt.Sprintf(format, args...) + if logger != nil { + logger.Errorf("Tool error: %s", message) + } + return mcp.NewToolResultError(message), nil +}