fix(mcp): accept string type for all parameter types in MCP server (#8866)

This commit is contained in:
Rubén De la Torre Vico
2025-10-08 10:31:57 +02:00
committed by GitHub
parent c7d7ec9a3b
commit 5cfe140b7b

View File

@@ -11,7 +11,7 @@ import os
import re
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from typing import Optional
import requests
import yaml
@@ -23,10 +23,10 @@ class OpenAPIToMCPGenerator:
self,
spec_file: str,
custom_auth_module: Optional[str] = None,
exclude_patterns: Optional[List[str]] = None,
exclude_operations: Optional[List[str]] = None,
exclude_tags: Optional[List[str]] = None,
include_only_tags: Optional[List[str]] = None,
exclude_patterns: Optional[list[str]] = None,
exclude_operations: Optional[list[str]] = None,
exclude_tags: Optional[list[str]] = None,
include_only_tags: Optional[list[str]] = None,
config_file: Optional[str] = None,
):
"""
@@ -35,9 +35,9 @@ class OpenAPIToMCPGenerator:
Args:
spec_file: Path to OpenAPI specification file
custom_auth_module: Module path for custom authentication
exclude_patterns: List of regex patterns to exclude endpoints (matches against path)
exclude_operations: List of operation IDs to exclude
exclude_tags: List of tags to exclude
exclude_patterns: list of regex patterns to exclude endpoints (matches against path)
exclude_operations: list of operation IDs to exclude
exclude_tags: list of tags to exclude
include_only_tags: If specified, only include endpoints with these tags
config_file: Path to JSON configuration file for custom mappings
"""
@@ -54,26 +54,24 @@ class OpenAPIToMCPGenerator:
self.imports = set()
self.type_mapping = {
"string": "str",
"integer": "int",
"number": "float",
"boolean": "bool",
"array": "str",
"object": "Dict[str, Any]",
"integer": "int | str",
"number": "float | str",
"boolean": "bool | str",
"array": "list[Any] | str",
"object": "dict[str, Any] | str",
}
def _load_config(self) -> Dict:
def _load_config(self) -> dict:
"""Load configuration from JSON file."""
try:
with open(self.config_file, "r") as f:
return json.load(f)
except FileNotFoundError:
# print(f"Warning: Config file {self.config_file} not found. Using defaults.")
return {}
except json.JSONDecodeError:
# print(f"Warning: Error parsing config file: {e}. Using defaults.")
return {}
def _load_spec(self) -> Dict:
def _load_spec(self) -> dict:
"""Load OpenAPI specification from file."""
with open(self.spec_file, "r") as f:
if self.spec_file.endswith(".yaml") or self.spec_file.endswith(".yml"):
@@ -81,7 +79,7 @@ class OpenAPIToMCPGenerator:
else:
return json.load(f)
def _get_endpoint_config(self, path: str, method: str) -> Dict:
def _get_endpoint_config(self, path: str, method: str) -> dict:
"""Get endpoint configuration from config file with pattern matching and inheritance.
Configuration resolution order (most to least specific):
@@ -153,7 +151,7 @@ class OpenAPIToMCPGenerator:
return merged_config
def _merge_configs(self, base_config: Dict, override_config: Dict) -> Dict:
def _merge_configs(self, base_config: dict, override_config: dict) -> dict:
"""Merge two configurations, with override_config taking precedence.
Special handling for parameters: merges parameter configurations deeply.
@@ -194,15 +192,19 @@ class OpenAPIToMCPGenerator:
name = f"op_{name}"
return name.lower()
def _get_python_type(self, schema: Dict) -> str:
"""Convert OpenAPI schema to Python type hint."""
def _get_python_type(self, schema: dict) -> tuple[str, str]:
"""Convert OpenAPI schema to Python type hint.
Returns:
Tuple of (type_hint, original_type) where original_type is used for casting
"""
if not schema:
return "Any"
return "Any", "any"
# Handle oneOf/anyOf/allOf schemas - these are typically objects
if "oneOf" in schema or "anyOf" in schema or "allOf" in schema:
# These are complex schemas, typically representing different object variants
return "Dict[str, Any]"
return "dict[str, Any] | str", "object"
schema_type = schema.get("type", "string")
@@ -210,30 +212,26 @@ class OpenAPIToMCPGenerator:
if "enum" in schema:
enum_values = schema["enum"]
if all(isinstance(v, str) for v in enum_values):
# Create Literal type for string enums
# Create Literal type for string enums - already strings, no casting needed
self.imports.add("from typing import Literal")
enum_str = ", ".join(f'"{v}"' for v in enum_values)
return f"Literal[{enum_str}]"
return f"Literal[{enum_str}]", "string"
else:
return self.type_mapping.get(schema_type, "Any")
return self.type_mapping.get(schema_type, "Any"), schema_type
# Handle arrays
if schema_type == "array":
return "str"
return "list[Any] | str", "array"
# Handle format specifications
if schema_type == "string":
format_type = schema.get("format", "")
if format_type in ["date", "date-time"]:
return "str" # Keep as string for API calls
elif format_type == "uuid":
return "str"
elif format_type == "email":
return "str"
if format_type in ["date", "date-time", "uuid", "email"]:
return "str", "string"
return self.type_mapping.get(schema_type, "Any")
return self.type_mapping.get(schema_type, "Any"), schema_type
def _resolve_ref(self, ref: str) -> Dict:
def _resolve_ref(self, ref: str) -> dict:
"""Resolve a $ref reference in the OpenAPI spec."""
if not ref.startswith("#/"):
return {}
@@ -249,8 +247,8 @@ class OpenAPIToMCPGenerator:
return resolved
def _extract_parameters(
self, operation: Dict, endpoint_config: Optional[Dict] = None
) -> List[Dict]:
self, operation: dict, endpoint_config: Optional[dict] = None
) -> list[dict]:
"""Extract and process parameters from an operation."""
parameters = []
@@ -264,13 +262,15 @@ class OpenAPIToMCPGenerator:
.replace("-", "_")
) # Also replace hyphens
type_hint, original_type = self._get_python_type(param.get("schema", {}))
param_info = {
"name": param.get("name", ""),
"python_name": python_name,
"in": param.get("in", "query"),
"required": param.get("required", False),
"description": param.get("description", ""),
"type": self._get_python_type(param.get("schema", {})),
"type": type_hint,
"original_type": original_type,
"original_schema": param.get("schema", {}),
}
@@ -323,7 +323,7 @@ class OpenAPIToMCPGenerator:
return parameters
def _extract_body_parameters(self, schema: Dict, is_required: bool) -> List[Dict]:
def _extract_body_parameters(self, schema: dict, is_required: bool) -> list[dict]:
"""Extract individual parameters from request body schema."""
parameters = []
@@ -346,6 +346,7 @@ class OpenAPIToMCPGenerator:
# Check if this field is required
is_field_required = prop_name in required_attrs
type_hint, original_type = self._get_python_type(prop_schema)
param_info = {
"name": prop_name, # Keep original name for API
"python_name": python_name,
@@ -355,7 +356,8 @@ class OpenAPIToMCPGenerator:
"description",
prop_schema.get("title", f"{prop_name} parameter"),
),
"type": self._get_python_type(prop_schema),
"type": type_hint,
"original_type": original_type,
"original_schema": prop_schema,
"resource_type": (
data["properties"]
@@ -383,6 +385,7 @@ class OpenAPIToMCPGenerator:
"required": is_rel_required,
"description": f"ID of the related {rel_name}",
"type": "str",
"original_type": "string",
"original_schema": rel_schema,
}
parameters.append(param_info)
@@ -396,7 +399,8 @@ class OpenAPIToMCPGenerator:
"in": "body",
"required": is_required,
"description": "Request body data",
"type": "Dict[str, Any]",
"type": "dict[str, Any] | str",
"original_type": "object",
"original_schema": schema,
}
)
@@ -405,11 +409,11 @@ class OpenAPIToMCPGenerator:
def _generate_docstring(
self,
operation: Dict,
parameters: List[Dict],
operation: dict,
parameters: list[dict],
path: str,
method: str,
endpoint_config: Optional[Dict] = None,
endpoint_config: Optional[dict] = None,
) -> str:
"""Generate a comprehensive docstring for the tool function."""
lines = []
@@ -447,7 +451,7 @@ class OpenAPIToMCPGenerator:
lines.append(" Args:")
for param in parameters:
# Use custom description if available
param_desc = param["description"] or "No description provided"
param_desc = param["description"] or "Self-explanatory parameter"
# Handle multi-line descriptions properly
required_text = "(required)" if param["required"] else "(optional)"
@@ -481,13 +485,13 @@ class OpenAPIToMCPGenerator:
# Returns section
lines.append("")
lines.append(" Returns:")
lines.append(" Dict containing the API response")
lines.append(" dict containing the API response")
lines.append(' """')
return "\n".join(lines)
def _generate_function_signature(
self, func_name: str, parameters: List[Dict]
self, func_name: str, parameters: list[dict]
) -> str:
"""Generate the function signature with proper type hints."""
# Sort parameters: required first, then optional
@@ -506,12 +510,38 @@ class OpenAPIToMCPGenerator:
if param_strings:
params_str = ",\n".join(param_strings)
return f"async def {func_name}(\n{params_str}\n) -> Dict[str, Any]:"
return f"async def {func_name}(\n{params_str}\n) -> dict[str, Any]:"
else:
return f"async def {func_name}() -> Dict[str, Any]:"
return f"async def {func_name}() -> dict[str, Any]:"
def _get_cast_expression(self, param: dict) -> str:
"""Generate type casting expression for a parameter.
Args:
param: Parameter dict with 'python_name' and 'original_type'
Returns:
Expression string that casts the parameter value to the correct type
"""
python_name = param["python_name"]
original_type = param.get("original_type", "string")
if original_type == "integer":
return f"int({python_name}) if isinstance({python_name}, str) else {python_name}"
elif original_type == "number":
return f"float({python_name}) if isinstance({python_name}, str) else {python_name}"
elif original_type == "boolean":
return f"({python_name}.lower() in ['true', '1', 'yes'] if isinstance({python_name}, str) else bool({python_name}))"
elif original_type == "array":
return f"json.loads({python_name}) if isinstance({python_name}, str) else {python_name}"
elif original_type == "object":
return f"json.loads({python_name}) if isinstance({python_name}, str) else {python_name}"
else:
# string or any other type - no casting needed
return python_name
def _generate_function_body(
self, path: str, method: str, parameters: List[Dict], operation_id: str
self, path: str, method: str, parameters: list[dict], operation_id: str
) -> str:
"""Generate the function body for making API calls."""
lines = []
@@ -529,26 +559,28 @@ class OpenAPIToMCPGenerator:
path_params = [p for p in parameters if p["in"] == "path"]
body_params = [p for p in parameters if p["in"] == "body"]
# Add json import if needed for object or array type casting
if any(p.get("original_type") in ["object", "array"] for p in parameters):
self.imports.add("import json")
# Build query parameters
if query_params:
lines.append(" params = {}")
for param in query_params:
cast_expr = self._get_cast_expression(param)
if param["required"]:
lines.append(
f" params['{param['name']}'] = {param['python_name']}"
)
lines.append(f" params['{param['name']}'] = {cast_expr}")
else:
lines.append(f" if {param['python_name']} is not None:")
lines.append(
f" params['{param['name']}'] = {param['python_name']}"
)
lines.append(f" params['{param['name']}'] = {cast_expr}")
lines.append("")
# Build path with path parameters
final_path = path
for param in path_params:
cast_expr = self._get_cast_expression(param)
lines.append(
f" path = '{path}'.replace('{{{param['name']}}}', str({param['python_name']}))"
f" path = '{path}'.replace('{{{param['name']}}}', str({cast_expr}))"
)
final_path = "path"
@@ -556,8 +588,9 @@ class OpenAPIToMCPGenerator:
if body_params:
# Check if we have individual params or a single body param
if len(body_params) == 1 and body_params[0]["python_name"] == "body":
# Single body parameter - use it directly
lines.append(" request_body = body")
# Single body parameter - use it directly with casting
cast_expr = self._get_cast_expression(body_params[0])
lines.append(f" request_body = {cast_expr}")
else:
# Get resource type from first body param (they should all have the same)
resource_type = (
@@ -598,16 +631,17 @@ class OpenAPIToMCPGenerator:
lines.append("")
lines.append(" # Add attributes")
for param in attribute_params:
cast_expr = self._get_cast_expression(param)
if param["required"]:
lines.append(
f' request_body["data"]["attributes"]["{param["name"]}"] = {param["python_name"]}'
f' request_body["data"]["attributes"]["{param["name"]}"] = {cast_expr}'
)
else:
lines.append(
f" if {param['python_name']} is not None:"
)
lines.append(
f' request_body["data"]["attributes"]["{param["name"]}"] = {param["python_name"]}'
f' request_body["data"]["attributes"]["{param["name"]}"] = {cast_expr}'
)
if relationship_params:
@@ -616,15 +650,14 @@ class OpenAPIToMCPGenerator:
lines.append(' request_body["data"]["relationships"] = {}')
for param in relationship_params:
rel_name = param["python_name"].replace("_id", "")
cast_expr = self._get_cast_expression(param)
if param["required"]:
lines.append(
f' request_body["data"]["relationships"]["{rel_name}"] = {{'
)
lines.append(' "data": {')
lines.append(f' "type": "{rel_name}s",')
lines.append(
f' "id": {param["python_name"]}'
)
lines.append(f' "id": {cast_expr}')
lines.append(" }")
lines.append(" }")
else:
@@ -636,24 +669,22 @@ class OpenAPIToMCPGenerator:
)
lines.append(' "data": {')
lines.append(f' "type": "{rel_name}s",')
lines.append(
f' "id": {param["python_name"]}'
)
lines.append(f' "id": {cast_expr}')
lines.append(" }")
lines.append(" }")
lines.append("")
# Prepare HTTP client call
lines.append(" async with httpx.AsyncClient() as client:")
# Build the request URL
url_line = (
f'f"{{auth_manager.base_url}}{{{final_path}}}"'
if final_path == "path"
else f'f"{{auth_manager.base_url}}{path}"'
)
lines.append(f" url = {url_line}")
lines.append("")
# Build the request
request_params = [
(
f'f"{{auth_manager.base_url}}{{{final_path}}}"'
if final_path == "path"
else f'f"{{auth_manager.base_url}}{path}"'
)
]
# Build request parameters
request_params = ["url"]
if self.custom_auth_module:
request_params.append("headers=auth_manager.get_headers(token)")
@@ -664,24 +695,21 @@ class OpenAPIToMCPGenerator:
if body_params:
request_params.append("json=request_body")
request_params.append("timeout=30.0")
params_str = ",\n ".join(request_params)
params_str = ",\n ".join(request_params)
lines.append(f" response = await client.{method}(")
lines.append(f" {params_str}")
lines.append(" )")
lines.append(" response.raise_for_status()")
lines.append(f" response = await prowler_app_client.{method}(")
lines.append(f" {params_str}")
lines.append(" )")
lines.append(" response.raise_for_status()")
lines.append("")
# Parse response
lines.append(" data = response.json()")
lines.append(" data = response.json()")
lines.append("")
lines.append(" return {")
lines.append(' "success": True,')
lines.append(' "data": data.get("data", data),')
lines.append(' "meta": data.get("meta", {})')
lines.append(" }")
lines.append(" return {")
lines.append(' "success": True,')
lines.append(' "data": data.get("data", data),')
lines.append(" }")
lines.append("")
# Exception handling
@@ -695,7 +723,7 @@ class OpenAPIToMCPGenerator:
return "\n".join(lines)
def _should_exclude_endpoint(self, path: str, operation: Dict) -> bool:
def _should_exclude_endpoint(self, path: str, operation: dict) -> bool:
"""
Determine if an endpoint should be excluded from generation.
@@ -750,7 +778,7 @@ class OpenAPIToMCPGenerator:
output_lines.append("")
# Add imports
self.imports.add("from typing import Dict, Any, Optional")
self.imports.add("from typing import Any, Optional")
self.imports.add("import httpx")
self.imports.add("from fastmcp import FastMCP")
@@ -848,6 +876,11 @@ class OpenAPIToMCPGenerator:
output_lines.append("# Initialize authentication manager")
output_lines.append("auth_manager = ProwlerAppAuth()")
output_lines.append("")
output_lines.append("# Initialize HTTP client")
output_lines.append("prowler_app_client = httpx.AsyncClient(")
output_lines.append(" timeout=30.0,")
output_lines.append(")")
output_lines.append("")
# Write tools grouped by tag
for tag, tools in tools_by_tag.items():
@@ -867,45 +900,6 @@ class OpenAPIToMCPGenerator:
"""Save the generated code to a file."""
generated_code = self.generate_tools()
Path(output_file).write_text(generated_code)
# print(f"Generated FastMCP server saved to: {output_file}")
# # Report statistics
# paths = self.spec.get("paths", {})
# total_endpoints = sum(
# len(
# [m for m in ["get", "post", "put", "patch", "delete"] if m in path_item]
# )
# for path_item in paths.values()
# )
# # Count excluded endpoints by reason
# excluded_count = 0
# deprecated_count = 0
# for path, path_item in paths.items():
# for method in ["get", "post", "put", "patch", "delete"]:
# if method in path_item:
# operation = path_item[method]
# if operation.get("deprecated", False):
# deprecated_count += 1
# if self._should_exclude_endpoint(path, operation):
# excluded_count += 1
# generated_count = total_endpoints - excluded_count
# print(f"Total endpoints in spec: {total_endpoints}")
# print(f"Endpoints excluded: {excluded_count}")
# if deprecated_count > 0:
# print(f" - Deprecated: {deprecated_count}")
# print(f"Endpoints generated: {generated_count}")
# Show exclusion rules if any
# if self.exclude_patterns:
# # print(f"Excluded patterns: {self.exclude_patterns}")
# if self.exclude_operations:
# # print(f"Excluded operations: {self.exclude_operations}")
# if self.exclude_tags:
# # print(f"Excluded tags: {self.exclude_tags}")
# if self.include_only_tags:
# # print(f"Including only tags: {self.include_only_tags}")
def generate_server_file():