From 55a43d4d7341a2ff9e7e465ca0150c4677406bc5 Mon Sep 17 00:00:00 2001 From: Yeliz Henden Date: Thu, 9 Apr 2026 15:43:43 +0100 Subject: [PATCH 1/3] feat(mcp): add search tool --- tools/mcp-server/internal/tools/search.go | 489 ++++++++++++++ .../mcp-server/internal/tools/search_test.go | 634 ++++++++++++++++++ tools/mcp-server/internal/tools/tools.go | 15 + 3 files changed, 1138 insertions(+) create mode 100644 tools/mcp-server/internal/tools/search.go create mode 100644 tools/mcp-server/internal/tools/search_test.go diff --git a/tools/mcp-server/internal/tools/search.go b/tools/mcp-server/internal/tools/search.go new file mode 100644 index 0000000000..9a39866fc5 --- /dev/null +++ b/tools/mcp-server/internal/tools/search.go @@ -0,0 +1,489 @@ +package tools + +import ( + "fmt" + "regexp" + "strings" + + "github.com/mongodb/openapi/tools/mcp-server/internal/registry" + "github.com/oasdiff/kin-openapi/openapi3" +) + +// SearchParams are the parameters for the search tool. +type SearchParams struct { + Alias string `json:"alias" jsonschema:"Alias of the spec to search"` + Pattern string `json:"pattern" jsonschema:"Regular expression pattern to search for"` + SearchIn []string `json:"searchIn,omitempty" jsonschema:"Optional: categories to search (operations, schemas, parameters, responses, tags, paths). Default: all"` + CaseSensitive bool `json:"caseSensitive,omitempty" jsonschema:"Optional: case-sensitive search (default: false)"` + Limit int `json:"limit,omitempty" jsonschema:"Optional: maximum results per category (default: 100)"` +} + +// SearchResult is the result of a search operation. +type SearchResult struct { + Success bool `json:"success"` + Alias string `json:"alias"` + Pattern string `json:"pattern"` + Operations []OperationMatch `json:"operations"` + Schemas []SchemaMatch `json:"schemas"` + Parameters []ParameterMatch `json:"parameters"` + Responses []ResponseMatch `json:"responses"` + Tags []TagMatch `json:"tags"` + Paths []PathMatch `json:"paths"` + Pagination PaginationMetadata `json:"pagination"` +} + +// PaginationMetadata contains pagination information. +type PaginationMetadata struct { + Limit int `json:"limit"` + TotalMatches int `json:"totalMatches"` + CategoryCounts map[string]int `json:"categoryCounts"` + CategoryHasMore map[string]bool `json:"categoryHasMore,omitempty"` +} + +// OperationMatch represents a matched operation. +type OperationMatch struct { + Path string `json:"path"` + Method string `json:"method"` + OperationID string `json:"operationId,omitempty"` + Summary string `json:"summary,omitempty"` + Description string `json:"description,omitempty"` + Tags []string `json:"tags,omitempty"` + MatchedIn []string `json:"matchedIn"` + MatchedText string `json:"matchedText"` +} + +// SchemaMatch represents a matched schema. +type SchemaMatch struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + MatchedIn []string `json:"matchedIn"` + MatchedText string `json:"matchedText"` + MatchedProperties []string `json:"matchedProperties,omitempty"` +} + +// ParameterMatch represents a matched parameter. +type ParameterMatch struct { + Name string `json:"name"` + In string `json:"in,omitempty"` + Description string `json:"description,omitempty"` + MatchedIn []string `json:"matchedIn"` + MatchedText string `json:"matchedText"` +} + +// ResponseMatch represents a matched response. +type ResponseMatch struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + MatchedIn []string `json:"matchedIn"` + MatchedText string `json:"matchedText"` +} + +// TagMatch represents a matched tag. +type TagMatch struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + MatchedIn []string `json:"matchedIn"` + MatchedText string `json:"matchedText"` +} + +// PathMatch represents a matched path. +type PathMatch struct { + Path string `json:"path"` + MatchedIn []string `json:"matchedIn"` + MatchedText string `json:"matchedText"` +} + +// validSearchCategories defines the valid categories for searchIn parameter. +var validSearchCategories = map[string]bool{ + "operations": true, + "schemas": true, + "parameters": true, + "responses": true, + "tags": true, + "paths": true, +} + +// Helper methods for SearchResult + +// totalCount returns the total number of matches across all categories. +func (r *SearchResult) totalCount() int { + return len(r.Operations) + len(r.Schemas) + len(r.Parameters) + + len(r.Responses) + len(r.Tags) + len(r.Paths) +} + +// categoryCounts returns a map of category names to their match counts. +func (r *SearchResult) categoryCounts() map[string]int { + return map[string]int{ + "operations": len(r.Operations), + "schemas": len(r.Schemas), + "parameters": len(r.Parameters), + "responses": len(r.Responses), + "tags": len(r.Tags), + "paths": len(r.Paths), + } +} + +// Helper functions for matching + +// checkAndRecordMatch checks if a value matches the regex and records the match. +func checkAndRecordMatch(re *regexp.Regexp, fieldName, value string, matchedIn, matchedTexts *[]string) { + if value != "" && re.MatchString(value) { + *matchedIn = append(*matchedIn, fieldName) + *matchedTexts = append(*matchedTexts, re.FindString(value)) + } +} + +// handleSearch searches for matches in an OpenAPI spec. +func handleSearch(reg *registry.Registry, params SearchParams) (SearchResult, error) { + // Set defaults + if params.Limit == 0 { + params.Limit = 100 + } + + // Validate searchIn categories + if len(params.SearchIn) > 0 { + var invalidCategories []string + for _, category := range params.SearchIn { + if !validSearchCategories[category] { + invalidCategories = append(invalidCategories, category) + } + } + if len(invalidCategories) > 0 { + return SearchResult{Success: false}, fmt.Errorf( + "invalid searchIn categories: %v. Valid categories are: operations, schemas, parameters, responses, tags, paths", + invalidCategories, + ) + } + } + + // Validate and compile regex + var re *regexp.Regexp + var err error + if params.CaseSensitive { + re, err = regexp.Compile(params.Pattern) + } else { + re, err = regexp.Compile("(?i)" + params.Pattern) + } + if err != nil { + return SearchResult{Success: false}, fmt.Errorf("invalid regex pattern: %w", err) + } + + // Get spec from registry + entry, err := reg.GetByAlias(params.Alias) + if err != nil { + return SearchResult{Success: false}, err + } + + // Determine what to search + searchIn := params.SearchIn + if len(searchIn) == 0 { + searchIn = []string{"operations", "schemas", "parameters", "responses", "tags", "paths"} + } + + // Perform search + result := SearchResult{ + Success: true, + Alias: params.Alias, + Pattern: params.Pattern, + } + + for _, category := range searchIn { + switch category { + case "operations": + result.Operations = searchOperations(entry.Spec, re) + case "schemas": + result.Schemas = searchSchemas(entry.Spec, re) + case "parameters": + result.Parameters = searchParameters(entry.Spec, re) + case "responses": + result.Responses = searchResponses(entry.Spec, re) + case "tags": + result.Tags = searchTags(entry.Spec, re) + case "paths": + result.Paths = searchPaths(entry.Spec, re) + } + } + + // Apply per-category pagination + result = applyPagination(result, params.Limit) + + return result, nil +} + +// searchOperations searches for matches in operations. +func searchOperations(spec *openapi3.T, re *regexp.Regexp) []OperationMatch { + var matches []OperationMatch + + if spec.Paths == nil { + return matches + } + + for path, pathItem := range spec.Paths.Map() { + for method, operation := range pathItem.Operations() { + match := OperationMatch{ + Path: path, + Method: strings.ToUpper(method), + OperationID: operation.OperationID, + Summary: operation.Summary, + Description: operation.Description, + Tags: operation.Tags, + } + + var matchedIn []string + var matchedTexts []string + + // Search in various fields + checkAndRecordMatch(re, "operationId", operation.OperationID, &matchedIn, &matchedTexts) + checkAndRecordMatch(re, "summary", operation.Summary, &matchedIn, &matchedTexts) + checkAndRecordMatch(re, "description", operation.Description, &matchedIn, &matchedTexts) + + // Search in tags + for _, tag := range operation.Tags { + if re.MatchString(tag) { + matchedIn = append(matchedIn, "tags") + matchedTexts = append(matchedTexts, re.FindString(tag)) + break + } + } + + if len(matchedIn) > 0 { + match.MatchedIn = matchedIn + match.MatchedText = strings.Join(matchedTexts, ", ") + matches = append(matches, match) + } + } + } + + return matches +} + +// searchSchemas searches for matches in component schemas. +func searchSchemas(spec *openapi3.T, re *regexp.Regexp) []SchemaMatch { + var matches []SchemaMatch + + if spec.Components == nil || spec.Components.Schemas == nil { + return matches + } + + for name, schemaRef := range spec.Components.Schemas { + if schemaRef == nil || schemaRef.Value == nil { + continue + } + + schema := schemaRef.Value + match := SchemaMatch{ + Name: name, + Description: schema.Description, + } + + var matchedIn []string + var matchedTexts []string + var matchedProps []string + + // Search in schema name and description + checkAndRecordMatch(re, "name", name, &matchedIn, &matchedTexts) + checkAndRecordMatch(re, "description", schema.Description, &matchedIn, &matchedTexts) + + // Search in property names + if schema.Properties != nil { + for propName := range schema.Properties { + if re.MatchString(propName) { + matchedIn = append(matchedIn, "properties") + matchedProps = append(matchedProps, propName) + matchedTexts = append(matchedTexts, re.FindString(propName)) + } + } + } + + if len(matchedIn) > 0 { + match.MatchedIn = matchedIn + match.MatchedText = strings.Join(matchedTexts, ", ") + if len(matchedProps) > 0 { + match.MatchedProperties = matchedProps + } + matches = append(matches, match) + } + } + + return matches +} + +// searchParameters searches for matches in component parameters. +func searchParameters(spec *openapi3.T, re *regexp.Regexp) []ParameterMatch { + var matches []ParameterMatch + + if spec.Components == nil || spec.Components.Parameters == nil { + return matches + } + + for name, paramRef := range spec.Components.Parameters { + if paramRef == nil || paramRef.Value == nil { + continue + } + + param := paramRef.Value + match := ParameterMatch{ + Name: param.Name, + In: param.In, + Description: param.Description, + } + + var matchedIn []string + var matchedTexts []string + + // Search in parameter name (check both component name and param.Name) + if re.MatchString(name) || re.MatchString(param.Name) { + matchedIn = append(matchedIn, "name") + matchedTexts = append(matchedTexts, re.FindString(param.Name)) + } + + // Search in description + checkAndRecordMatch(re, "description", param.Description, &matchedIn, &matchedTexts) + + if len(matchedIn) > 0 { + match.MatchedIn = matchedIn + match.MatchedText = strings.Join(matchedTexts, ", ") + matches = append(matches, match) + } + } + + return matches +} + +// searchResponses searches for matches in component responses. +func searchResponses(spec *openapi3.T, re *regexp.Regexp) []ResponseMatch { + var matches []ResponseMatch + + if spec.Components == nil || spec.Components.Responses == nil { + return matches + } + + for name, respRef := range spec.Components.Responses { + if respRef == nil || respRef.Value == nil { + continue + } + + resp := respRef.Value + description := "" + if resp.Description != nil { + description = *resp.Description + } + + match := ResponseMatch{ + Name: name, + Description: description, + } + + var matchedIn []string + var matchedTexts []string + + // Search in response name and description + checkAndRecordMatch(re, "name", name, &matchedIn, &matchedTexts) + checkAndRecordMatch(re, "description", description, &matchedIn, &matchedTexts) + + if len(matchedIn) > 0 { + match.MatchedIn = matchedIn + match.MatchedText = strings.Join(matchedTexts, ", ") + matches = append(matches, match) + } + } + + return matches +} + +// searchTags searches for matches in tags. +func searchTags(spec *openapi3.T, re *regexp.Regexp) []TagMatch { + var matches []TagMatch + + if spec.Tags == nil { + return matches + } + + for _, tag := range spec.Tags { + match := TagMatch{ + Name: tag.Name, + Description: tag.Description, + } + + var matchedIn []string + var matchedTexts []string + + // Search in tag name and description + checkAndRecordMatch(re, "name", tag.Name, &matchedIn, &matchedTexts) + checkAndRecordMatch(re, "description", tag.Description, &matchedIn, &matchedTexts) + + if len(matchedIn) > 0 { + match.MatchedIn = matchedIn + match.MatchedText = strings.Join(matchedTexts, ", ") + matches = append(matches, match) + } + } + + return matches +} + +// searchPaths searches for matches in path patterns. +func searchPaths(spec *openapi3.T, re *regexp.Regexp) []PathMatch { + var matches []PathMatch + + if spec.Paths == nil { + return matches + } + + for path := range spec.Paths.Map() { + if re.MatchString(path) { + match := PathMatch{ + Path: path, + MatchedIn: []string{"path"}, + MatchedText: re.FindString(path), + } + matches = append(matches, match) + } + } + + return matches +} + +// applyPagination applies per-category limit to search results. +func applyPagination(result SearchResult, limit int) SearchResult { + // Store counts before truncation + totalMatches := result.totalCount() + categoryCounts := result.categoryCounts() + categoryHasMore := make(map[string]bool) + + // Apply limit to each category + if len(result.Operations) > limit { + result.Operations = result.Operations[:limit] + categoryHasMore["operations"] = true + } + if len(result.Schemas) > limit { + result.Schemas = result.Schemas[:limit] + categoryHasMore["schemas"] = true + } + if len(result.Parameters) > limit { + result.Parameters = result.Parameters[:limit] + categoryHasMore["parameters"] = true + } + if len(result.Responses) > limit { + result.Responses = result.Responses[:limit] + categoryHasMore["responses"] = true + } + if len(result.Tags) > limit { + result.Tags = result.Tags[:limit] + categoryHasMore["tags"] = true + } + if len(result.Paths) > limit { + result.Paths = result.Paths[:limit] + categoryHasMore["paths"] = true + } + + // Build pagination metadata + result.Pagination = PaginationMetadata{ + Limit: limit, + TotalMatches: totalMatches, + CategoryCounts: categoryCounts, + CategoryHasMore: categoryHasMore, + } + + return result +} diff --git a/tools/mcp-server/internal/tools/search_test.go b/tools/mcp-server/internal/tools/search_test.go new file mode 100644 index 0000000000..a5a00b6e35 --- /dev/null +++ b/tools/mcp-server/internal/tools/search_test.go @@ -0,0 +1,634 @@ +package tools + +import ( + "strings" + "testing" + + "github.com/mongodb/openapi/tools/mcp-server/internal/registry" + "github.com/oasdiff/kin-openapi/openapi3" +) + +// createTestSpec creates a comprehensive test OpenAPI spec for search testing. +func createTestSpec() *openapi3.T { + spec := &openapi3.T{ + OpenAPI: "3.0.0", + Info: &openapi3.Info{ + Title: "Test API", + Version: "1.0.0", + }, + Paths: &openapi3.Paths{}, + Components: &openapi3.Components{}, + Tags: []*openapi3.Tag{ + {Name: "Users", Description: "User management endpoints"}, + {Name: "Clusters", Description: "Cluster operations"}, + }, + } + + // Add operations + spec.Paths.Set("/users", &openapi3.PathItem{ + Get: &openapi3.Operation{ + OperationID: "getUsers", + Summary: "Get all users", + Description: "Retrieve a list of all users in the system", + Tags: []string{"Users"}, + }, + Post: &openapi3.Operation{ + OperationID: "createUser", + Summary: "Create a user", + Description: "Create a new user account", + Tags: []string{"Users"}, + }, + }) + + spec.Paths.Set("/users/{userId}", &openapi3.PathItem{ + Get: &openapi3.Operation{ + OperationID: "getUser", + Summary: "Get user by ID", + Description: "Retrieve a specific user by their ID", + Tags: []string{"Users"}, + }, + }) + + spec.Paths.Set("/clusters", &openapi3.PathItem{ + Post: &openapi3.Operation{ + OperationID: "createCluster", + Summary: "Create a new cluster", + Description: "Creates a new cluster in the project", + Tags: []string{"Clusters"}, + }, + Get: &openapi3.Operation{ + OperationID: "listClusters", + Summary: "List clusters", + Description: "Get all clusters in the project", + Tags: []string{"Clusters"}, + }, + }) + + spec.Paths.Set("/clusters/{clusterId}", &openapi3.PathItem{ + Get: &openapi3.Operation{ + OperationID: "getCluster", + Summary: "Get cluster details", + Description: "Retrieve details for a specific cluster", + Tags: []string{"Clusters"}, + }, + }) + + // Add schemas + spec.Components.Schemas = make(map[string]*openapi3.SchemaRef) + spec.Components.Schemas["User"] = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Description: "User account information", + Properties: map[string]*openapi3.SchemaRef{ + "userId": {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + "username": {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + "email": {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + }, + }, + } + + spec.Components.Schemas["Cluster"] = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Description: "Cluster configuration", + Properties: map[string]*openapi3.SchemaRef{ + "clusterId": {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + "name": {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + "region": {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + }, + }, + } + + spec.Components.Schemas["Database"] = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Description: "Database information", + Properties: map[string]*openapi3.SchemaRef{ + "databaseName": {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + }, + }, + } + + // Add parameters + spec.Components.Parameters = make(map[string]*openapi3.ParameterRef) + spec.Components.Parameters["userId"] = &openapi3.ParameterRef{ + Value: &openapi3.Parameter{ + Name: "userId", + In: "path", + Description: "Unique identifier for the user", + Required: true, + }, + } + + spec.Components.Parameters["clusterId"] = &openapi3.ParameterRef{ + Value: &openapi3.Parameter{ + Name: "clusterId", + In: "path", + Description: "Unique identifier for the cluster", + Required: true, + }, + } + + // Add responses + spec.Components.Responses = make(map[string]*openapi3.ResponseRef) + notFound := "Not Found" + spec.Components.Responses["NotFound"] = &openapi3.ResponseRef{ + Value: &openapi3.Response{ + Description: ¬Found, + }, + } + + unauthorized := "Unauthorized" + spec.Components.Responses["Unauthorized"] = &openapi3.ResponseRef{ + Value: &openapi3.Response{ + Description: &unauthorized, + }, + } + + return spec +} + +// setupTestRegistry creates a registry with the test spec loaded. +func setupTestRegistry(t *testing.T) *registry.Registry { + reg := registry.New() + spec := createTestSpec() + err := reg.Add("test-api", "/test/api.yaml", spec, nil) + if err != nil { + t.Fatalf("Failed to add spec: %v", err) + } + return reg +} + +// ExpectedResults defines expected search results for table-driven tests +type ExpectedResults struct { + Operations []string + Schemas []string + Parameters []string + Paths []string + Tags []string + TotalCount int +} + +// assertSearchResults verifies all aspects of search results +func assertSearchResults(t *testing.T, result SearchResult, expected ExpectedResults) { + t.Helper() + assertExactOperationIDs(t, result.Operations, expected.Operations) + assertExactSchemaNames(t, result.Schemas, expected.Schemas) + assertExactParameterNames(t, result.Parameters, expected.Parameters) + assertExactPaths(t, result.Paths, expected.Paths) + + if len(expected.Tags) > 0 { + if len(result.Tags) != len(expected.Tags) { + t.Errorf("Expected %d tags, got %d", len(expected.Tags), len(result.Tags)) + } + for i, expectedTag := range expected.Tags { + if i < len(result.Tags) && result.Tags[i].Name != expectedTag { + t.Errorf("Expected tag '%s', got '%s'", expectedTag, result.Tags[i].Name) + } + } + } + + if expected.TotalCount > 0 && result.Pagination.TotalMatches != expected.TotalCount { + t.Errorf("Expected totalMatches=%d, got %d", expected.TotalCount, result.Pagination.TotalMatches) + } +} + +// assertExactOperationIDs verifies the exact set of operation IDs +func assertExactOperationIDs(t *testing.T, operations []OperationMatch, expectedIDs []string) { + t.Helper() + if len(operations) != len(expectedIDs) { + t.Errorf("Expected %d operations, got %d", len(expectedIDs), len(operations)) + t.Logf("Got: %v", getOperationIDs(operations)) + t.Logf("Expected: %v", expectedIDs) + } + found := make(map[string]bool) + for _, op := range operations { + found[op.OperationID] = true + } + for _, expectedID := range expectedIDs { + if !found[expectedID] { + t.Errorf("Expected operation ID '%s' not found", expectedID) + } + } + expectedSet := make(map[string]bool) + for _, id := range expectedIDs { + expectedSet[id] = true + } + for _, op := range operations { + if !expectedSet[op.OperationID] { + t.Errorf("Unexpected operation ID '%s' found", op.OperationID) + } + } +} + +// assertExactSchemaNames verifies the exact set of schema names +func assertExactSchemaNames(t *testing.T, schemas []SchemaMatch, expectedNames []string) { + t.Helper() + if len(schemas) != len(expectedNames) { + t.Errorf("Expected %d schemas, got %d", len(expectedNames), len(schemas)) + t.Logf("Got: %v", getSchemaNames(schemas)) + t.Logf("Expected: %v", expectedNames) + } + found := make(map[string]bool) + for _, schema := range schemas { + found[schema.Name] = true + } + for _, expectedName := range expectedNames { + if !found[expectedName] { + t.Errorf("Expected schema '%s' not found", expectedName) + } + } +} + +// assertExactParameterNames verifies the exact set of parameter names +func assertExactParameterNames(t *testing.T, parameters []ParameterMatch, expectedNames []string) { + t.Helper() + if len(parameters) != len(expectedNames) { + t.Errorf("Expected %d parameters, got %d", len(expectedNames), len(parameters)) + t.Logf("Got: %v", getParameterNames(parameters)) + t.Logf("Expected: %v", expectedNames) + } + found := make(map[string]bool) + for _, param := range parameters { + found[param.Name] = true + } + for _, expectedName := range expectedNames { + if !found[expectedName] { + t.Errorf("Expected parameter '%s' not found", expectedName) + } + } +} + +// assertExactPaths verifies the exact set of paths +func assertExactPaths(t *testing.T, paths []PathMatch, expectedPaths []string) { + t.Helper() + if len(paths) != len(expectedPaths) { + t.Errorf("Expected %d paths, got %d", len(expectedPaths), len(paths)) + t.Logf("Got: %v", getPaths(paths)) + t.Logf("Expected: %v", expectedPaths) + } + found := make(map[string]bool) + for _, path := range paths { + found[path.Path] = true + } + for _, expectedPath := range expectedPaths { + if !found[expectedPath] { + t.Errorf("Expected path '%s' not found", expectedPath) + } + } +} + +// Helper functions to extract names/IDs for logging +func getOperationIDs(operations []OperationMatch) []string { + ids := make([]string, len(operations)) + for i, op := range operations { + ids[i] = op.OperationID + } + return ids +} + +func getSchemaNames(schemas []SchemaMatch) []string { + names := make([]string, len(schemas)) + for i, schema := range schemas { + names[i] = schema.Name + } + return names +} + +func getParameterNames(parameters []ParameterMatch) []string { + names := make([]string, len(parameters)) + for i, param := range parameters { + names[i] = param.Name + } + return names +} + +func getPaths(paths []PathMatch) []string { + pathStrs := make([]string, len(paths)) + for i, path := range paths { + pathStrs[i] = path.Path + } + return pathStrs +} + +func TestHandleSearch_Patterns(t *testing.T) { + reg := setupTestRegistry(t) + + tests := []struct { + name string + params SearchParams + expected ExpectedResults + checkFn func(*testing.T, SearchResult) // Optional additional checks + }{ + { + name: "pattern: user", + params: SearchParams{ + Alias: "test-api", + Pattern: "user", + }, + expected: ExpectedResults{ + Operations: []string{"getUsers", "createUser", "getUser"}, + Schemas: []string{"User"}, + Parameters: []string{"userId"}, + Paths: []string{"/users", "/users/{userId}"}, + Tags: []string{"Users"}, + TotalCount: 8, + }, + checkFn: func(t *testing.T, result SearchResult) { + // Verify matchedIn is populated + for _, op := range result.Operations { + if len(op.MatchedIn) == 0 { + t.Errorf("Expected matchedIn populated for %s", op.OperationID) + } + } + }, + }, + { + name: "pattern: cluster", + params: SearchParams{ + Alias: "test-api", + Pattern: "cluster", + }, + expected: ExpectedResults{ + Operations: []string{"createCluster", "listClusters", "getCluster"}, + Schemas: []string{"Cluster"}, + Parameters: []string{"clusterId"}, + Paths: []string{"/clusters", "/clusters/{clusterId}"}, + Tags: []string{"Clusters"}, + TotalCount: 8, + }, + checkFn: func(t *testing.T, result SearchResult) { + // Verify Cluster schema matched by both name and property + if len(result.Schemas) > 0 { + schema := result.Schemas[0] + hasName := false + hasProps := false + for _, field := range schema.MatchedIn { + if field == "name" { + hasName = true + } + if field == "properties" { + hasProps = true + } + } + if !hasName || !hasProps { + t.Errorf("Expected Cluster to match by name and properties, got: %v", schema.MatchedIn) + } + } + }, + }, + { + name: "case-insensitive: USER", + params: SearchParams{ + Alias: "test-api", + Pattern: "USER", + }, + expected: ExpectedResults{ + Operations: []string{"getUsers", "createUser", "getUser"}, + Schemas: []string{"User"}, + Parameters: []string{"userId"}, + Paths: []string{"/users", "/users/{userId}"}, + Tags: []string{"Users"}, + TotalCount: 8, + }, + }, + { + name: "case-sensitive: USER (no match)", + params: SearchParams{ + Alias: "test-api", + Pattern: "USER", + CaseSensitive: true, + }, + expected: ExpectedResults{ + Operations: []string{}, + Schemas: []string{}, + Parameters: []string{}, + Paths: []string{}, + Tags: []string{}, + TotalCount: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := handleSearch(reg, tt.params) + if err != nil { + t.Fatalf("handleSearch() failed: %v", err) + } + + if !result.Success { + t.Error("Expected success=true") + } + + assertSearchResults(t, result, tt.expected) + + if tt.checkFn != nil { + tt.checkFn(t, result) + } + }) + } +} + +func TestHandleSearch_InvalidRegex(t *testing.T) { + reg := setupTestRegistry(t) + + params := SearchParams{ + Alias: "test-api", + Pattern: "[invalid(regex", + } + + _, err := handleSearch(reg, params) + if err == nil { + t.Error("Expected error for invalid regex") + } + + if !strings.Contains(err.Error(), "invalid regex") && !strings.Contains(err.Error(), "error parsing regexp") { + t.Errorf("Expected regex error message, got: %v", err) + } +} + +func TestHandleSearch_Pagination(t *testing.T) { + reg := registry.New() + + spec := &openapi3.T{ + OpenAPI: "3.0.0", + Info: &openapi3.Info{ + Title: "Test API", + Version: "1.0.0", + }, + Paths: &openapi3.Paths{}, + } + + // Add multiple operations (15 to test limit of 5) + for i := 0; i < 15; i++ { + path := "/endpoint" + string(rune('a'+i)) + spec.Paths.Set(path, &openapi3.PathItem{ + Get: &openapi3.Operation{ + OperationID: "getEndpoint" + string(rune('A'+i)), + Summary: "Test endpoint " + string(rune('A'+i)), + }, + }) + } + + err := reg.Add("test-api", "/test/api.yaml", spec, nil) + if err != nil { + t.Fatalf("Failed to add spec: %v", err) + } + + // Test with limit=5 per category + params := SearchParams{ + Alias: "test-api", + Pattern: "endpoint", + Limit: 5, + } + + result, err := handleSearch(reg, params) + if err != nil { + t.Fatalf("handleSearch() failed: %v", err) + } + + // We should get 5 operations and 5 paths (limited) + if len(result.Operations) != 5 { + t.Errorf("Expected 5 operations (limited), got %d", len(result.Operations)) + } + + if len(result.Paths) != 5 { + t.Errorf("Expected 5 paths (limited), got %d", len(result.Paths)) + } + + // Total matches should be 30 (15 paths + 15 operations) + if result.Pagination.TotalMatches != 30 { + t.Errorf("Expected totalMatches=30, got %d", result.Pagination.TotalMatches) + } + + // Verify pagination metadata is exact + if result.Pagination.Limit != 5 { + t.Errorf("Expected limit=5, got %d", result.Pagination.Limit) + } + + if result.Pagination.TotalMatches != 30 { + t.Errorf("Expected totalMatches=30 (15 paths + 15 operations), got %d", result.Pagination.TotalMatches) + } + + // Verify category counts (before truncation) + if result.Pagination.CategoryCounts["operations"] != 15 { + t.Errorf("Expected categoryCounts['operations']=15, got %d", result.Pagination.CategoryCounts["operations"]) + } + + if result.Pagination.CategoryCounts["paths"] != 15 { + t.Errorf("Expected categoryCounts['paths']=15, got %d", result.Pagination.CategoryCounts["paths"]) + } + + // Should indicate more available for both categories + if !result.Pagination.CategoryHasMore["operations"] { + t.Error("Expected categoryHasMore['operations']=true (15 total, limit 5)") + } + + if !result.Pagination.CategoryHasMore["paths"] { + t.Error("Expected categoryHasMore['paths']=true (15 total, limit 5)") + } + + // Other categories should not be in categoryHasMore + if result.Pagination.CategoryHasMore["schemas"] { + t.Error("Expected categoryHasMore['schemas'] to be false (0 matches)") + } +} + +func TestHandleSearch_SearchInFilter(t *testing.T) { + reg := setupTestRegistry(t) + + tests := []struct { + name string + searchIn []string + expected ExpectedResults + }{ + { + name: "only schemas", + searchIn: []string{"schemas"}, + expected: ExpectedResults{ + Operations: []string{}, + Schemas: []string{"User"}, + Parameters: []string{}, + Paths: []string{}, + Tags: []string{}, + }, + }, + { + name: "only operations", + searchIn: []string{"operations"}, + expected: ExpectedResults{ + Operations: []string{"getUsers", "createUser", "getUser"}, + Schemas: []string{}, + Parameters: []string{}, + Paths: []string{}, + Tags: []string{}, + }, + }, + { + name: "operations and schemas", + searchIn: []string{"operations", "schemas"}, + expected: ExpectedResults{ + Operations: []string{"getUsers", "createUser", "getUser"}, + Schemas: []string{"User"}, + Parameters: []string{}, + Paths: []string{}, + Tags: []string{}, + }, + }, + { + name: "empty searchIn (all categories)", + searchIn: []string{}, + expected: ExpectedResults{ + Operations: []string{"getUsers", "createUser", "getUser"}, + Schemas: []string{"User"}, + Parameters: []string{"userId"}, + Paths: []string{"/users", "/users/{userId}"}, + Tags: []string{"Users"}, + TotalCount: 8, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := SearchParams{ + Alias: "test-api", + Pattern: "user", + SearchIn: tt.searchIn, + } + + result, err := handleSearch(reg, params) + if err != nil { + t.Fatalf("handleSearch() failed: %v", err) + } + + assertSearchResults(t, result, tt.expected) + }) + } +} + +func TestHandleSearch_InvalidSearchInCategory(t *testing.T) { + reg := setupTestRegistry(t) + + params := SearchParams{ + Alias: "test-api", + Pattern: "test", + SearchIn: []string{"operations", "invalid-category", "foo"}, + } + + _, err := handleSearch(reg, params) + if err == nil { + t.Error("Expected error for invalid searchIn categories") + } + + expectedError := "invalid searchIn categories" + if !strings.Contains(err.Error(), expectedError) { + t.Errorf("Expected error to contain '%s', got: %v", expectedError, err) + } + + // Should mention the invalid categories + if !strings.Contains(err.Error(), "invalid-category") || !strings.Contains(err.Error(), "foo") { + t.Errorf("Expected error to mention invalid categories, got: %v", err) + } +} diff --git a/tools/mcp-server/internal/tools/tools.go b/tools/mcp-server/internal/tools/tools.go index 2546727889..6967f57364 100644 --- a/tools/mcp-server/internal/tools/tools.go +++ b/tools/mcp-server/internal/tools/tools.go @@ -29,6 +29,13 @@ func Register(server *mcp.Server, reg *registry.Registry) { Description: "Export a loaded OpenAPI specification to a file", } mcp.AddTool(server, exportTool, makeExportHandler(reg)) + + // Register search tool + searchTool := &mcp.Tool{ + Name: "search", + Description: "Search for patterns in an OpenAPI specification using regular expressions", + } + mcp.AddTool(server, searchTool, makeSearchHandler(reg)) } // makeLoadHandler creates the handler for the load tool. @@ -54,3 +61,11 @@ func makeExportHandler(reg *registry.Registry) mcp.ToolHandlerFor[ExportParams, return nil, result, err } } + +// makeSearchHandler creates the handler for the search tool. +func makeSearchHandler(reg *registry.Registry) mcp.ToolHandlerFor[SearchParams, SearchResult] { + return func(_ context.Context, _ *mcp.CallToolRequest, params SearchParams) (*mcp.CallToolResult, SearchResult, error) { + result, err := handleSearch(reg, params) + return nil, result, err + } +} From aaa949008eb1dcf5a7a0c5ab4f4371c9f87eab4d Mon Sep 17 00:00:00 2001 From: Yeliz Henden Date: Thu, 9 Apr 2026 16:39:23 +0100 Subject: [PATCH 2/3] lining fixes --- tools/mcp-server/internal/tools/search.go | 12 +++--- .../mcp-server/internal/tools/search_test.go | 43 ++++++++++--------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/tools/mcp-server/internal/tools/search.go b/tools/mcp-server/internal/tools/search.go index 9a39866fc5..0479e23bd1 100644 --- a/tools/mcp-server/internal/tools/search.go +++ b/tools/mcp-server/internal/tools/search.go @@ -13,9 +13,9 @@ import ( type SearchParams struct { Alias string `json:"alias" jsonschema:"Alias of the spec to search"` Pattern string `json:"pattern" jsonschema:"Regular expression pattern to search for"` - SearchIn []string `json:"searchIn,omitempty" jsonschema:"Optional: categories to search (operations, schemas, parameters, responses, tags, paths). Default: all"` - CaseSensitive bool `json:"caseSensitive,omitempty" jsonschema:"Optional: case-sensitive search (default: false)"` - Limit int `json:"limit,omitempty" jsonschema:"Optional: maximum results per category (default: 100)"` + SearchIn []string `json:"searchIn,omitempty" jsonschema:"Optional: categories to search"` + CaseSensitive bool `json:"caseSensitive,omitempty" jsonschema:"Optional: case-sensitive search"` + Limit int `json:"limit,omitempty" jsonschema:"Optional: max results per category"` } // SearchResult is the result of a search operation. @@ -205,7 +205,7 @@ func handleSearch(reg *registry.Registry, params SearchParams) (SearchResult, er } // Apply per-category pagination - result = applyPagination(result, params.Limit) + applyPagination(&result, params.Limit) return result, nil } @@ -445,7 +445,7 @@ func searchPaths(spec *openapi3.T, re *regexp.Regexp) []PathMatch { } // applyPagination applies per-category limit to search results. -func applyPagination(result SearchResult, limit int) SearchResult { +func applyPagination(result *SearchResult, limit int) { // Store counts before truncation totalMatches := result.totalCount() categoryCounts := result.categoryCounts() @@ -484,6 +484,4 @@ func applyPagination(result SearchResult, limit int) SearchResult { CategoryCounts: categoryCounts, CategoryHasMore: categoryHasMore, } - - return result } diff --git a/tools/mcp-server/internal/tools/search_test.go b/tools/mcp-server/internal/tools/search_test.go index a5a00b6e35..01666e4074 100644 --- a/tools/mcp-server/internal/tools/search_test.go +++ b/tools/mcp-server/internal/tools/search_test.go @@ -150,6 +150,7 @@ func createTestSpec() *openapi3.T { // setupTestRegistry creates a registry with the test spec loaded. func setupTestRegistry(t *testing.T) *registry.Registry { + t.Helper() reg := registry.New() spec := createTestSpec() err := reg.Add("test-api", "/test/api.yaml", spec, nil) @@ -159,7 +160,7 @@ func setupTestRegistry(t *testing.T) *registry.Registry { return reg } -// ExpectedResults defines expected search results for table-driven tests +// ExpectedResults defines expected search results for table-driven tests. type ExpectedResults struct { Operations []string Schemas []string @@ -169,8 +170,8 @@ type ExpectedResults struct { TotalCount int } -// assertSearchResults verifies all aspects of search results -func assertSearchResults(t *testing.T, result SearchResult, expected ExpectedResults) { +// assertSearchResults verifies all aspects of search results. +func assertSearchResults(t *testing.T, result *SearchResult, expected *ExpectedResults) { t.Helper() assertExactOperationIDs(t, result.Operations, expected.Operations) assertExactSchemaNames(t, result.Schemas, expected.Schemas) @@ -193,7 +194,7 @@ func assertSearchResults(t *testing.T, result SearchResult, expected ExpectedRes } } -// assertExactOperationIDs verifies the exact set of operation IDs +// assertExactOperationIDs verifies the exact set of operation IDs. func assertExactOperationIDs(t *testing.T, operations []OperationMatch, expectedIDs []string) { t.Helper() if len(operations) != len(expectedIDs) { @@ -202,8 +203,8 @@ func assertExactOperationIDs(t *testing.T, operations []OperationMatch, expected t.Logf("Expected: %v", expectedIDs) } found := make(map[string]bool) - for _, op := range operations { - found[op.OperationID] = true + for i := range operations { + found[operations[i].OperationID] = true } for _, expectedID := range expectedIDs { if !found[expectedID] { @@ -214,14 +215,14 @@ func assertExactOperationIDs(t *testing.T, operations []OperationMatch, expected for _, id := range expectedIDs { expectedSet[id] = true } - for _, op := range operations { - if !expectedSet[op.OperationID] { - t.Errorf("Unexpected operation ID '%s' found", op.OperationID) + for i := range operations { + if !expectedSet[operations[i].OperationID] { + t.Errorf("Unexpected operation ID '%s' found", operations[i].OperationID) } } } -// assertExactSchemaNames verifies the exact set of schema names +// assertExactSchemaNames verifies the exact set of schema names. func assertExactSchemaNames(t *testing.T, schemas []SchemaMatch, expectedNames []string) { t.Helper() if len(schemas) != len(expectedNames) { @@ -240,7 +241,7 @@ func assertExactSchemaNames(t *testing.T, schemas []SchemaMatch, expectedNames [ } } -// assertExactParameterNames verifies the exact set of parameter names +// assertExactParameterNames verifies the exact set of parameter names. func assertExactParameterNames(t *testing.T, parameters []ParameterMatch, expectedNames []string) { t.Helper() if len(parameters) != len(expectedNames) { @@ -259,7 +260,7 @@ func assertExactParameterNames(t *testing.T, parameters []ParameterMatch, expect } } -// assertExactPaths verifies the exact set of paths +// assertExactPaths verifies the exact set of paths. func assertExactPaths(t *testing.T, paths []PathMatch, expectedPaths []string) { t.Helper() if len(paths) != len(expectedPaths) { @@ -278,11 +279,11 @@ func assertExactPaths(t *testing.T, paths []PathMatch, expectedPaths []string) { } } -// Helper functions to extract names/IDs for logging +// Helper functions to extract names/IDs for logging. func getOperationIDs(operations []OperationMatch) []string { ids := make([]string, len(operations)) - for i, op := range operations { - ids[i] = op.OperationID + for i := range operations { + ids[i] = operations[i].OperationID } return ids } @@ -335,10 +336,11 @@ func TestHandleSearch_Patterns(t *testing.T) { TotalCount: 8, }, checkFn: func(t *testing.T, result SearchResult) { + t.Helper() // Verify matchedIn is populated - for _, op := range result.Operations { - if len(op.MatchedIn) == 0 { - t.Errorf("Expected matchedIn populated for %s", op.OperationID) + for i := range result.Operations { + if len(result.Operations[i].MatchedIn) == 0 { + t.Errorf("Expected matchedIn populated for %s", result.Operations[i].OperationID) } } }, @@ -358,6 +360,7 @@ func TestHandleSearch_Patterns(t *testing.T) { TotalCount: 8, }, checkFn: func(t *testing.T, result SearchResult) { + t.Helper() // Verify Cluster schema matched by both name and property if len(result.Schemas) > 0 { schema := result.Schemas[0] @@ -421,7 +424,7 @@ func TestHandleSearch_Patterns(t *testing.T) { t.Error("Expected success=true") } - assertSearchResults(t, result, tt.expected) + assertSearchResults(t, &result, &tt.expected) if tt.checkFn != nil { tt.checkFn(t, result) @@ -603,7 +606,7 @@ func TestHandleSearch_SearchInFilter(t *testing.T) { t.Fatalf("handleSearch() failed: %v", err) } - assertSearchResults(t, result, tt.expected) + assertSearchResults(t, &result, &tt.expected) }) } } From b4078504710373f3bd98cec905e0d26b0653b83a Mon Sep 17 00:00:00 2001 From: Yeliz Henden Date: Fri, 10 Apr 2026 14:58:22 +0100 Subject: [PATCH 3/3] add more detailed tool description --- tools/mcp-server/internal/tools/tools.go | 27 +++++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/tools/mcp-server/internal/tools/tools.go b/tools/mcp-server/internal/tools/tools.go index 6967f57364..a2261aee2b 100644 --- a/tools/mcp-server/internal/tools/tools.go +++ b/tools/mcp-server/internal/tools/tools.go @@ -11,29 +11,40 @@ import ( func Register(server *mcp.Server, reg *registry.Registry) { // Register load tool loadTool := &mcp.Tool{ - Name: "load", - Description: "Load an OpenAPI specification from a file into memory", + Name: "load", + Description: "Load an OpenAPI specification file (JSON or YAML) into memory. " + + "Assigns an alias for easy reference in other commands. " + + "Validates the spec and makes it available for searching, exporting, and inspection. " + + "Required before using search or export tools on a spec.", } mcp.AddTool(server, loadTool, makeLoadHandler(reg)) // Register unload tool unloadTool := &mcp.Tool{ - Name: "unload", - Description: "Remove a loaded OpenAPI specification from memory", + Name: "unload", + Description: "Remove a previously loaded OpenAPI specification from memory by its alias. " + + "Frees up resources and removes the spec from the available list. " + + "Use this to clean up specs you no longer need to search or reference.", } mcp.AddTool(server, unloadTool, makeUnloadHandler(reg)) // Register export tool exportTool := &mcp.Tool{ - Name: "export", - Description: "Export a loaded OpenAPI specification to a file", + Name: "export", + Description: "Export a loaded OpenAPI specification to a file in JSON or YAML format. " + + "Useful for converting between formats, saving modified specs, or creating copies. " + + "Specify the output path and desired format (json/yaml).", } mcp.AddTool(server, exportTool, makeExportHandler(reg)) // Register search tool searchTool := &mcp.Tool{ - Name: "search", - Description: "Search for patterns in an OpenAPI specification using regular expressions", + Name: "search", + Description: "Search across OpenAPI specifications using regex patterns. " + + "Searches operations, schemas, parameters, responses, tags, and paths. " + + "Returns matches grouped by category with details on what matched (operationId, schema name, property names, etc.). " + + "Supports case-sensitive/insensitive search, category filtering, and per-category result limits. " + + "Use this to find specific endpoints, data structures, or API components by name or pattern.", } mcp.AddTool(server, searchTool, makeSearchHandler(reg)) }