mirror of
https://github.com/prowler-cloud/prowler.git
synced 2025-12-19 05:17:47 +00:00
fix(mcp): accept string type for all parameter types in MCP server (#8866)
This commit is contained in:
committed by
GitHub
parent
c7d7ec9a3b
commit
5cfe140b7b
@@ -11,7 +11,7 @@ import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -23,10 +23,10 @@ class OpenAPIToMCPGenerator:
|
||||
self,
|
||||
spec_file: str,
|
||||
custom_auth_module: Optional[str] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
exclude_operations: Optional[List[str]] = None,
|
||||
exclude_tags: Optional[List[str]] = None,
|
||||
include_only_tags: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[list[str]] = None,
|
||||
exclude_operations: Optional[list[str]] = None,
|
||||
exclude_tags: Optional[list[str]] = None,
|
||||
include_only_tags: Optional[list[str]] = None,
|
||||
config_file: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@@ -35,9 +35,9 @@ class OpenAPIToMCPGenerator:
|
||||
Args:
|
||||
spec_file: Path to OpenAPI specification file
|
||||
custom_auth_module: Module path for custom authentication
|
||||
exclude_patterns: List of regex patterns to exclude endpoints (matches against path)
|
||||
exclude_operations: List of operation IDs to exclude
|
||||
exclude_tags: List of tags to exclude
|
||||
exclude_patterns: list of regex patterns to exclude endpoints (matches against path)
|
||||
exclude_operations: list of operation IDs to exclude
|
||||
exclude_tags: list of tags to exclude
|
||||
include_only_tags: If specified, only include endpoints with these tags
|
||||
config_file: Path to JSON configuration file for custom mappings
|
||||
"""
|
||||
@@ -54,26 +54,24 @@ class OpenAPIToMCPGenerator:
|
||||
self.imports = set()
|
||||
self.type_mapping = {
|
||||
"string": "str",
|
||||
"integer": "int",
|
||||
"number": "float",
|
||||
"boolean": "bool",
|
||||
"array": "str",
|
||||
"object": "Dict[str, Any]",
|
||||
"integer": "int | str",
|
||||
"number": "float | str",
|
||||
"boolean": "bool | str",
|
||||
"array": "list[Any] | str",
|
||||
"object": "dict[str, Any] | str",
|
||||
}
|
||||
|
||||
def _load_config(self) -> Dict:
|
||||
def _load_config(self) -> dict:
|
||||
"""Load configuration from JSON file."""
|
||||
try:
|
||||
with open(self.config_file, "r") as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
# print(f"Warning: Config file {self.config_file} not found. Using defaults.")
|
||||
return {}
|
||||
except json.JSONDecodeError:
|
||||
# print(f"Warning: Error parsing config file: {e}. Using defaults.")
|
||||
return {}
|
||||
|
||||
def _load_spec(self) -> Dict:
|
||||
def _load_spec(self) -> dict:
|
||||
"""Load OpenAPI specification from file."""
|
||||
with open(self.spec_file, "r") as f:
|
||||
if self.spec_file.endswith(".yaml") or self.spec_file.endswith(".yml"):
|
||||
@@ -81,7 +79,7 @@ class OpenAPIToMCPGenerator:
|
||||
else:
|
||||
return json.load(f)
|
||||
|
||||
def _get_endpoint_config(self, path: str, method: str) -> Dict:
|
||||
def _get_endpoint_config(self, path: str, method: str) -> dict:
|
||||
"""Get endpoint configuration from config file with pattern matching and inheritance.
|
||||
|
||||
Configuration resolution order (most to least specific):
|
||||
@@ -153,7 +151,7 @@ class OpenAPIToMCPGenerator:
|
||||
|
||||
return merged_config
|
||||
|
||||
def _merge_configs(self, base_config: Dict, override_config: Dict) -> Dict:
|
||||
def _merge_configs(self, base_config: dict, override_config: dict) -> dict:
|
||||
"""Merge two configurations, with override_config taking precedence.
|
||||
|
||||
Special handling for parameters: merges parameter configurations deeply.
|
||||
@@ -194,15 +192,19 @@ class OpenAPIToMCPGenerator:
|
||||
name = f"op_{name}"
|
||||
return name.lower()
|
||||
|
||||
def _get_python_type(self, schema: Dict) -> str:
|
||||
"""Convert OpenAPI schema to Python type hint."""
|
||||
def _get_python_type(self, schema: dict) -> tuple[str, str]:
|
||||
"""Convert OpenAPI schema to Python type hint.
|
||||
|
||||
Returns:
|
||||
Tuple of (type_hint, original_type) where original_type is used for casting
|
||||
"""
|
||||
if not schema:
|
||||
return "Any"
|
||||
return "Any", "any"
|
||||
|
||||
# Handle oneOf/anyOf/allOf schemas - these are typically objects
|
||||
if "oneOf" in schema or "anyOf" in schema or "allOf" in schema:
|
||||
# These are complex schemas, typically representing different object variants
|
||||
return "Dict[str, Any]"
|
||||
return "dict[str, Any] | str", "object"
|
||||
|
||||
schema_type = schema.get("type", "string")
|
||||
|
||||
@@ -210,30 +212,26 @@ class OpenAPIToMCPGenerator:
|
||||
if "enum" in schema:
|
||||
enum_values = schema["enum"]
|
||||
if all(isinstance(v, str) for v in enum_values):
|
||||
# Create Literal type for string enums
|
||||
# Create Literal type for string enums - already strings, no casting needed
|
||||
self.imports.add("from typing import Literal")
|
||||
enum_str = ", ".join(f'"{v}"' for v in enum_values)
|
||||
return f"Literal[{enum_str}]"
|
||||
return f"Literal[{enum_str}]", "string"
|
||||
else:
|
||||
return self.type_mapping.get(schema_type, "Any")
|
||||
return self.type_mapping.get(schema_type, "Any"), schema_type
|
||||
|
||||
# Handle arrays
|
||||
if schema_type == "array":
|
||||
return "str"
|
||||
return "list[Any] | str", "array"
|
||||
|
||||
# Handle format specifications
|
||||
if schema_type == "string":
|
||||
format_type = schema.get("format", "")
|
||||
if format_type in ["date", "date-time"]:
|
||||
return "str" # Keep as string for API calls
|
||||
elif format_type == "uuid":
|
||||
return "str"
|
||||
elif format_type == "email":
|
||||
return "str"
|
||||
if format_type in ["date", "date-time", "uuid", "email"]:
|
||||
return "str", "string"
|
||||
|
||||
return self.type_mapping.get(schema_type, "Any")
|
||||
return self.type_mapping.get(schema_type, "Any"), schema_type
|
||||
|
||||
def _resolve_ref(self, ref: str) -> Dict:
|
||||
def _resolve_ref(self, ref: str) -> dict:
|
||||
"""Resolve a $ref reference in the OpenAPI spec."""
|
||||
if not ref.startswith("#/"):
|
||||
return {}
|
||||
@@ -249,8 +247,8 @@ class OpenAPIToMCPGenerator:
|
||||
return resolved
|
||||
|
||||
def _extract_parameters(
|
||||
self, operation: Dict, endpoint_config: Optional[Dict] = None
|
||||
) -> List[Dict]:
|
||||
self, operation: dict, endpoint_config: Optional[dict] = None
|
||||
) -> list[dict]:
|
||||
"""Extract and process parameters from an operation."""
|
||||
parameters = []
|
||||
|
||||
@@ -264,13 +262,15 @@ class OpenAPIToMCPGenerator:
|
||||
.replace("-", "_")
|
||||
) # Also replace hyphens
|
||||
|
||||
type_hint, original_type = self._get_python_type(param.get("schema", {}))
|
||||
param_info = {
|
||||
"name": param.get("name", ""),
|
||||
"python_name": python_name,
|
||||
"in": param.get("in", "query"),
|
||||
"required": param.get("required", False),
|
||||
"description": param.get("description", ""),
|
||||
"type": self._get_python_type(param.get("schema", {})),
|
||||
"type": type_hint,
|
||||
"original_type": original_type,
|
||||
"original_schema": param.get("schema", {}),
|
||||
}
|
||||
|
||||
@@ -323,7 +323,7 @@ class OpenAPIToMCPGenerator:
|
||||
|
||||
return parameters
|
||||
|
||||
def _extract_body_parameters(self, schema: Dict, is_required: bool) -> List[Dict]:
|
||||
def _extract_body_parameters(self, schema: dict, is_required: bool) -> list[dict]:
|
||||
"""Extract individual parameters from request body schema."""
|
||||
parameters = []
|
||||
|
||||
@@ -346,6 +346,7 @@ class OpenAPIToMCPGenerator:
|
||||
# Check if this field is required
|
||||
is_field_required = prop_name in required_attrs
|
||||
|
||||
type_hint, original_type = self._get_python_type(prop_schema)
|
||||
param_info = {
|
||||
"name": prop_name, # Keep original name for API
|
||||
"python_name": python_name,
|
||||
@@ -355,7 +356,8 @@ class OpenAPIToMCPGenerator:
|
||||
"description",
|
||||
prop_schema.get("title", f"{prop_name} parameter"),
|
||||
),
|
||||
"type": self._get_python_type(prop_schema),
|
||||
"type": type_hint,
|
||||
"original_type": original_type,
|
||||
"original_schema": prop_schema,
|
||||
"resource_type": (
|
||||
data["properties"]
|
||||
@@ -383,6 +385,7 @@ class OpenAPIToMCPGenerator:
|
||||
"required": is_rel_required,
|
||||
"description": f"ID of the related {rel_name}",
|
||||
"type": "str",
|
||||
"original_type": "string",
|
||||
"original_schema": rel_schema,
|
||||
}
|
||||
parameters.append(param_info)
|
||||
@@ -396,7 +399,8 @@ class OpenAPIToMCPGenerator:
|
||||
"in": "body",
|
||||
"required": is_required,
|
||||
"description": "Request body data",
|
||||
"type": "Dict[str, Any]",
|
||||
"type": "dict[str, Any] | str",
|
||||
"original_type": "object",
|
||||
"original_schema": schema,
|
||||
}
|
||||
)
|
||||
@@ -405,11 +409,11 @@ class OpenAPIToMCPGenerator:
|
||||
|
||||
def _generate_docstring(
|
||||
self,
|
||||
operation: Dict,
|
||||
parameters: List[Dict],
|
||||
operation: dict,
|
||||
parameters: list[dict],
|
||||
path: str,
|
||||
method: str,
|
||||
endpoint_config: Optional[Dict] = None,
|
||||
endpoint_config: Optional[dict] = None,
|
||||
) -> str:
|
||||
"""Generate a comprehensive docstring for the tool function."""
|
||||
lines = []
|
||||
@@ -447,7 +451,7 @@ class OpenAPIToMCPGenerator:
|
||||
lines.append(" Args:")
|
||||
for param in parameters:
|
||||
# Use custom description if available
|
||||
param_desc = param["description"] or "No description provided"
|
||||
param_desc = param["description"] or "Self-explanatory parameter"
|
||||
|
||||
# Handle multi-line descriptions properly
|
||||
required_text = "(required)" if param["required"] else "(optional)"
|
||||
@@ -481,13 +485,13 @@ class OpenAPIToMCPGenerator:
|
||||
# Returns section
|
||||
lines.append("")
|
||||
lines.append(" Returns:")
|
||||
lines.append(" Dict containing the API response")
|
||||
lines.append(" dict containing the API response")
|
||||
|
||||
lines.append(' """')
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_function_signature(
|
||||
self, func_name: str, parameters: List[Dict]
|
||||
self, func_name: str, parameters: list[dict]
|
||||
) -> str:
|
||||
"""Generate the function signature with proper type hints."""
|
||||
# Sort parameters: required first, then optional
|
||||
@@ -506,12 +510,38 @@ class OpenAPIToMCPGenerator:
|
||||
|
||||
if param_strings:
|
||||
params_str = ",\n".join(param_strings)
|
||||
return f"async def {func_name}(\n{params_str}\n) -> Dict[str, Any]:"
|
||||
return f"async def {func_name}(\n{params_str}\n) -> dict[str, Any]:"
|
||||
else:
|
||||
return f"async def {func_name}() -> Dict[str, Any]:"
|
||||
return f"async def {func_name}() -> dict[str, Any]:"
|
||||
|
||||
def _get_cast_expression(self, param: dict) -> str:
|
||||
"""Generate type casting expression for a parameter.
|
||||
|
||||
Args:
|
||||
param: Parameter dict with 'python_name' and 'original_type'
|
||||
|
||||
Returns:
|
||||
Expression string that casts the parameter value to the correct type
|
||||
"""
|
||||
python_name = param["python_name"]
|
||||
original_type = param.get("original_type", "string")
|
||||
|
||||
if original_type == "integer":
|
||||
return f"int({python_name}) if isinstance({python_name}, str) else {python_name}"
|
||||
elif original_type == "number":
|
||||
return f"float({python_name}) if isinstance({python_name}, str) else {python_name}"
|
||||
elif original_type == "boolean":
|
||||
return f"({python_name}.lower() in ['true', '1', 'yes'] if isinstance({python_name}, str) else bool({python_name}))"
|
||||
elif original_type == "array":
|
||||
return f"json.loads({python_name}) if isinstance({python_name}, str) else {python_name}"
|
||||
elif original_type == "object":
|
||||
return f"json.loads({python_name}) if isinstance({python_name}, str) else {python_name}"
|
||||
else:
|
||||
# string or any other type - no casting needed
|
||||
return python_name
|
||||
|
||||
def _generate_function_body(
|
||||
self, path: str, method: str, parameters: List[Dict], operation_id: str
|
||||
self, path: str, method: str, parameters: list[dict], operation_id: str
|
||||
) -> str:
|
||||
"""Generate the function body for making API calls."""
|
||||
lines = []
|
||||
@@ -529,26 +559,28 @@ class OpenAPIToMCPGenerator:
|
||||
path_params = [p for p in parameters if p["in"] == "path"]
|
||||
body_params = [p for p in parameters if p["in"] == "body"]
|
||||
|
||||
# Add json import if needed for object or array type casting
|
||||
if any(p.get("original_type") in ["object", "array"] for p in parameters):
|
||||
self.imports.add("import json")
|
||||
|
||||
# Build query parameters
|
||||
if query_params:
|
||||
lines.append(" params = {}")
|
||||
for param in query_params:
|
||||
cast_expr = self._get_cast_expression(param)
|
||||
if param["required"]:
|
||||
lines.append(
|
||||
f" params['{param['name']}'] = {param['python_name']}"
|
||||
)
|
||||
lines.append(f" params['{param['name']}'] = {cast_expr}")
|
||||
else:
|
||||
lines.append(f" if {param['python_name']} is not None:")
|
||||
lines.append(
|
||||
f" params['{param['name']}'] = {param['python_name']}"
|
||||
)
|
||||
lines.append(f" params['{param['name']}'] = {cast_expr}")
|
||||
lines.append("")
|
||||
|
||||
# Build path with path parameters
|
||||
final_path = path
|
||||
for param in path_params:
|
||||
cast_expr = self._get_cast_expression(param)
|
||||
lines.append(
|
||||
f" path = '{path}'.replace('{{{param['name']}}}', str({param['python_name']}))"
|
||||
f" path = '{path}'.replace('{{{param['name']}}}', str({cast_expr}))"
|
||||
)
|
||||
final_path = "path"
|
||||
|
||||
@@ -556,8 +588,9 @@ class OpenAPIToMCPGenerator:
|
||||
if body_params:
|
||||
# Check if we have individual params or a single body param
|
||||
if len(body_params) == 1 and body_params[0]["python_name"] == "body":
|
||||
# Single body parameter - use it directly
|
||||
lines.append(" request_body = body")
|
||||
# Single body parameter - use it directly with casting
|
||||
cast_expr = self._get_cast_expression(body_params[0])
|
||||
lines.append(f" request_body = {cast_expr}")
|
||||
else:
|
||||
# Get resource type from first body param (they should all have the same)
|
||||
resource_type = (
|
||||
@@ -598,16 +631,17 @@ class OpenAPIToMCPGenerator:
|
||||
lines.append("")
|
||||
lines.append(" # Add attributes")
|
||||
for param in attribute_params:
|
||||
cast_expr = self._get_cast_expression(param)
|
||||
if param["required"]:
|
||||
lines.append(
|
||||
f' request_body["data"]["attributes"]["{param["name"]}"] = {param["python_name"]}'
|
||||
f' request_body["data"]["attributes"]["{param["name"]}"] = {cast_expr}'
|
||||
)
|
||||
else:
|
||||
lines.append(
|
||||
f" if {param['python_name']} is not None:"
|
||||
)
|
||||
lines.append(
|
||||
f' request_body["data"]["attributes"]["{param["name"]}"] = {param["python_name"]}'
|
||||
f' request_body["data"]["attributes"]["{param["name"]}"] = {cast_expr}'
|
||||
)
|
||||
|
||||
if relationship_params:
|
||||
@@ -616,15 +650,14 @@ class OpenAPIToMCPGenerator:
|
||||
lines.append(' request_body["data"]["relationships"] = {}')
|
||||
for param in relationship_params:
|
||||
rel_name = param["python_name"].replace("_id", "")
|
||||
cast_expr = self._get_cast_expression(param)
|
||||
if param["required"]:
|
||||
lines.append(
|
||||
f' request_body["data"]["relationships"]["{rel_name}"] = {{'
|
||||
)
|
||||
lines.append(' "data": {')
|
||||
lines.append(f' "type": "{rel_name}s",')
|
||||
lines.append(
|
||||
f' "id": {param["python_name"]}'
|
||||
)
|
||||
lines.append(f' "id": {cast_expr}')
|
||||
lines.append(" }")
|
||||
lines.append(" }")
|
||||
else:
|
||||
@@ -636,24 +669,22 @@ class OpenAPIToMCPGenerator:
|
||||
)
|
||||
lines.append(' "data": {')
|
||||
lines.append(f' "type": "{rel_name}s",')
|
||||
lines.append(
|
||||
f' "id": {param["python_name"]}'
|
||||
)
|
||||
lines.append(f' "id": {cast_expr}')
|
||||
lines.append(" }")
|
||||
lines.append(" }")
|
||||
lines.append("")
|
||||
|
||||
# Prepare HTTP client call
|
||||
lines.append(" async with httpx.AsyncClient() as client:")
|
||||
# Build the request URL
|
||||
url_line = (
|
||||
f'f"{{auth_manager.base_url}}{{{final_path}}}"'
|
||||
if final_path == "path"
|
||||
else f'f"{{auth_manager.base_url}}{path}"'
|
||||
)
|
||||
lines.append(f" url = {url_line}")
|
||||
lines.append("")
|
||||
|
||||
# Build the request
|
||||
request_params = [
|
||||
(
|
||||
f'f"{{auth_manager.base_url}}{{{final_path}}}"'
|
||||
if final_path == "path"
|
||||
else f'f"{{auth_manager.base_url}}{path}"'
|
||||
)
|
||||
]
|
||||
# Build request parameters
|
||||
request_params = ["url"]
|
||||
|
||||
if self.custom_auth_module:
|
||||
request_params.append("headers=auth_manager.get_headers(token)")
|
||||
@@ -664,24 +695,21 @@ class OpenAPIToMCPGenerator:
|
||||
if body_params:
|
||||
request_params.append("json=request_body")
|
||||
|
||||
request_params.append("timeout=30.0")
|
||||
params_str = ",\n ".join(request_params)
|
||||
|
||||
params_str = ",\n ".join(request_params)
|
||||
|
||||
lines.append(f" response = await client.{method}(")
|
||||
lines.append(f" {params_str}")
|
||||
lines.append(" )")
|
||||
lines.append(" response.raise_for_status()")
|
||||
lines.append(f" response = await prowler_app_client.{method}(")
|
||||
lines.append(f" {params_str}")
|
||||
lines.append(" )")
|
||||
lines.append(" response.raise_for_status()")
|
||||
lines.append("")
|
||||
|
||||
# Parse response
|
||||
lines.append(" data = response.json()")
|
||||
lines.append(" data = response.json()")
|
||||
lines.append("")
|
||||
lines.append(" return {")
|
||||
lines.append(' "success": True,')
|
||||
lines.append(' "data": data.get("data", data),')
|
||||
lines.append(' "meta": data.get("meta", {})')
|
||||
lines.append(" }")
|
||||
lines.append(" return {")
|
||||
lines.append(' "success": True,')
|
||||
lines.append(' "data": data.get("data", data),')
|
||||
lines.append(" }")
|
||||
lines.append("")
|
||||
|
||||
# Exception handling
|
||||
@@ -695,7 +723,7 @@ class OpenAPIToMCPGenerator:
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _should_exclude_endpoint(self, path: str, operation: Dict) -> bool:
|
||||
def _should_exclude_endpoint(self, path: str, operation: dict) -> bool:
|
||||
"""
|
||||
Determine if an endpoint should be excluded from generation.
|
||||
|
||||
@@ -750,7 +778,7 @@ class OpenAPIToMCPGenerator:
|
||||
output_lines.append("")
|
||||
|
||||
# Add imports
|
||||
self.imports.add("from typing import Dict, Any, Optional")
|
||||
self.imports.add("from typing import Any, Optional")
|
||||
self.imports.add("import httpx")
|
||||
self.imports.add("from fastmcp import FastMCP")
|
||||
|
||||
@@ -848,6 +876,11 @@ class OpenAPIToMCPGenerator:
|
||||
output_lines.append("# Initialize authentication manager")
|
||||
output_lines.append("auth_manager = ProwlerAppAuth()")
|
||||
output_lines.append("")
|
||||
output_lines.append("# Initialize HTTP client")
|
||||
output_lines.append("prowler_app_client = httpx.AsyncClient(")
|
||||
output_lines.append(" timeout=30.0,")
|
||||
output_lines.append(")")
|
||||
output_lines.append("")
|
||||
|
||||
# Write tools grouped by tag
|
||||
for tag, tools in tools_by_tag.items():
|
||||
@@ -867,45 +900,6 @@ class OpenAPIToMCPGenerator:
|
||||
"""Save the generated code to a file."""
|
||||
generated_code = self.generate_tools()
|
||||
Path(output_file).write_text(generated_code)
|
||||
# print(f"Generated FastMCP server saved to: {output_file}")
|
||||
|
||||
# # Report statistics
|
||||
# paths = self.spec.get("paths", {})
|
||||
# total_endpoints = sum(
|
||||
# len(
|
||||
# [m for m in ["get", "post", "put", "patch", "delete"] if m in path_item]
|
||||
# )
|
||||
# for path_item in paths.values()
|
||||
# )
|
||||
|
||||
# # Count excluded endpoints by reason
|
||||
# excluded_count = 0
|
||||
# deprecated_count = 0
|
||||
# for path, path_item in paths.items():
|
||||
# for method in ["get", "post", "put", "patch", "delete"]:
|
||||
# if method in path_item:
|
||||
# operation = path_item[method]
|
||||
# if operation.get("deprecated", False):
|
||||
# deprecated_count += 1
|
||||
# if self._should_exclude_endpoint(path, operation):
|
||||
# excluded_count += 1
|
||||
|
||||
# generated_count = total_endpoints - excluded_count
|
||||
# print(f"Total endpoints in spec: {total_endpoints}")
|
||||
# print(f"Endpoints excluded: {excluded_count}")
|
||||
# if deprecated_count > 0:
|
||||
# print(f" - Deprecated: {deprecated_count}")
|
||||
# print(f"Endpoints generated: {generated_count}")
|
||||
|
||||
# Show exclusion rules if any
|
||||
# if self.exclude_patterns:
|
||||
# # print(f"Excluded patterns: {self.exclude_patterns}")
|
||||
# if self.exclude_operations:
|
||||
# # print(f"Excluded operations: {self.exclude_operations}")
|
||||
# if self.exclude_tags:
|
||||
# # print(f"Excluded tags: {self.exclude_tags}")
|
||||
# if self.include_only_tags:
|
||||
# # print(f"Including only tags: {self.include_only_tags}")
|
||||
|
||||
|
||||
def generate_server_file():
|
||||
|
||||
Reference in New Issue
Block a user