diff --git a/docs/getting-started/basic-usage/prowler-mcp-tools.mdx b/docs/getting-started/basic-usage/prowler-mcp-tools.mdx index 352f7d373f..bf0092a069 100644 --- a/docs/getting-started/basic-usage/prowler-mcp-tools.mdx +++ b/docs/getting-started/basic-usage/prowler-mcp-tools.mdx @@ -10,7 +10,7 @@ Complete reference guide for all tools available in the Prowler MCP Server. Tool |----------|------------|------------------------| | Prowler Hub | 10 tools | No | | Prowler Documentation | 2 tools | No | -| Prowler Cloud/App | 22 tools | Yes | +| Prowler Cloud/App | 24 tools | Yes | ## Tool Naming Convention @@ -80,6 +80,13 @@ Tools for managing finding muting, including pattern-based bulk muting (mutelist - **`prowler_app_update_mute_rule`** - Update a mute rule's name, reason, or enabled status - **`prowler_app_delete_mute_rule`** - Delete a mute rule from the system +### Compliance Management + +Tools for viewing compliance status and framework details across all cloud providers. + +- **`prowler_app_get_compliance_overview`** - Get high-level compliance status across all frameworks for a specific scan or provider, including pass/fail statistics per framework +- **`prowler_app_get_compliance_framework_state_details`** - Get detailed requirement-level breakdown for a specific compliance framework, including failed requirements and associated finding IDs + ## Prowler Hub Tools Access Prowler's security check catalog and compliance frameworks. **No authentication required.** diff --git a/mcp_server/CHANGELOG.md b/mcp_server/CHANGELOG.md index fb1545d009..e3f8c0a762 100644 --- a/mcp_server/CHANGELOG.md +++ b/mcp_server/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to the **Prowler MCP Server** are documented in this file. ## [0.3.0] (UNRELEASED) +### Added + +- Add new MCP Server tools for Prowler Compliance Framework Management [(#9568)](https://github.com/prowler-cloud/prowler/pull/9568) + ### Changed - Update API base URL environment variable to include complete path [(#9542)](https://github.com/prowler-cloud/prowler/pull/9542) diff --git a/mcp_server/README.md b/mcp_server/README.md index 3e1dec6bd0..41c5271d61 100644 --- a/mcp_server/README.md +++ b/mcp_server/README.md @@ -14,6 +14,7 @@ Full access to Prowler Cloud platform and self-managed Prowler App for: - **Scan Orchestration**: Trigger on-demand scans and schedule recurring security assessments - **Resource Inventory**: Search and view detailed information about your audited resources - **Muting Management**: Create and manage muting rules to suppress non-critical findings +- **Compliance Reporting**: View compliance status across frameworks and drill into requirement-level details ### Prowler Hub diff --git a/mcp_server/prowler_mcp_server/prowler_app/models/compliance.py b/mcp_server/prowler_mcp_server/prowler_app/models/compliance.py new file mode 100644 index 0000000000..4dfe5c4839 --- /dev/null +++ b/mcp_server/prowler_mcp_server/prowler_app/models/compliance.py @@ -0,0 +1,240 @@ +"""Pydantic models for simplified compliance responses.""" + +from typing import Any, Literal + +from prowler_mcp_server.prowler_app.models.base import MinimalSerializerMixin +from pydantic import ( + BaseModel, + ConfigDict, + Field, + SerializerFunctionWrapHandler, + model_serializer, +) + + +class ComplianceRequirementAttribute(MinimalSerializerMixin, BaseModel): + """Requirement attributes including associated check IDs. + + Used to map requirements to the checks that validate them. + """ + + model_config = ConfigDict(frozen=True) + + id: str = Field( + description="Requirement identifier within the framework (e.g., '1.1', '2.1.1')" + ) + name: str = Field(default="", description="Human-readable name of the requirement") + description: str = Field( + default="", description="Detailed description of the requirement" + ) + check_ids: list[str] = Field( + default_factory=list, + description="List of Prowler check IDs that validate this requirement", + ) + + @classmethod + def from_api_response(cls, data: dict) -> "ComplianceRequirementAttribute": + """Transform JSON:API compliance requirement attributes response to simplified format.""" + attributes = data.get("attributes", {}) + + # Extract check_ids from the nested attributes structure + nested_attributes = attributes.get("attributes", {}) + check_ids = nested_attributes.get("check_ids", []) + + return cls( + id=attributes.get("id", data.get("id", "")), + name=attributes.get("name", ""), + description=attributes.get("description", ""), + check_ids=check_ids if check_ids else [], + ) + + +class ComplianceRequirementAttributesListResponse(BaseModel): + """Response for compliance requirement attributes list with check_ids mappings.""" + + model_config = ConfigDict(frozen=True) + + requirements: list[ComplianceRequirementAttribute] = Field( + description="List of requirements with their associated check IDs" + ) + total_count: int = Field(description="Total number of requirements") + + @classmethod + def from_api_response( + cls, response: dict + ) -> "ComplianceRequirementAttributesListResponse": + """Transform JSON:API response to simplified format.""" + data = response.get("data", []) + + requirements = [ + ComplianceRequirementAttribute.from_api_response(item) for item in data + ] + + return cls( + requirements=requirements, + total_count=len(requirements), + ) + + +class ComplianceFrameworkSummary(MinimalSerializerMixin, BaseModel): + """Simplified compliance framework overview for list operations. + + Used by get_compliance_overview() to show high-level compliance status + per framework. + """ + + model_config = ConfigDict(frozen=True) + + id: str = Field(description="Unique identifier for this compliance overview entry") + compliance_id: str = Field( + description="Compliance framework identifier (e.g., 'cis_1.5_aws', 'pci_dss_v4.0_aws')" + ) + framework: str = Field( + description="Human-readable framework name (e.g., 'CIS', 'PCI-DSS', 'HIPAA')" + ) + version: str = Field(description="Framework version (e.g., '1.5', '4.0')") + total_requirements: int = Field( + default=0, description="Total number of requirements in this framework" + ) + requirements_passed: int = Field( + default=0, description="Number of requirements that passed" + ) + requirements_failed: int = Field( + default=0, description="Number of requirements that failed" + ) + requirements_manual: int = Field( + default=0, description="Number of requirements requiring manual verification" + ) + + @property + def pass_percentage(self) -> float: + """Calculate pass percentage based on passed requirements.""" + if self.total_requirements == 0: + return 0.0 + return round((self.requirements_passed / self.total_requirements) * 100, 1) + + @property + def fail_percentage(self) -> float: + """Calculate fail percentage based on failed requirements.""" + if self.total_requirements == 0: + return 0.0 + return round((self.requirements_failed / self.total_requirements) * 100, 1) + + @model_serializer(mode="wrap") + def _serialize(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]: + """Serialize with calculated percentages included.""" + data = handler(self) + # Filter out None/empty values + data = {k: v for k, v in data.items() if v is not None and v != "" and v != []} + # Add calculated percentages + data["pass_percentage"] = self.pass_percentage + data["fail_percentage"] = self.fail_percentage + return data + + @classmethod + def from_api_response(cls, data: dict) -> "ComplianceFrameworkSummary": + """Transform JSON:API compliance overview response to simplified format.""" + attributes = data.get("attributes", {}) + + # The compliance_id field may be in attributes or use the "id" field from attributes + compliance_id = attributes.get("id", data.get("id", "")) + + return cls( + id=data["id"], + compliance_id=compliance_id, + framework=attributes.get("framework", ""), + version=attributes.get("version", ""), + total_requirements=attributes.get("total_requirements", 0), + requirements_passed=attributes.get("requirements_passed", 0), + requirements_failed=attributes.get("requirements_failed", 0), + requirements_manual=attributes.get("requirements_manual", 0), + ) + + +class ComplianceRequirement(MinimalSerializerMixin, BaseModel): + """Individual compliance requirement with its status. + + Used by get_compliance_framework_state_details() to show requirement-level breakdown. + """ + + model_config = ConfigDict(frozen=True) + + id: str = Field( + description="Requirement identifier within the framework (e.g., '1.1', '2.1.1')" + ) + description: str = Field( + description="Human-readable description of the requirement" + ) + status: Literal["FAIL", "PASS", "MANUAL"] = Field( + description="Requirement status: FAIL (not compliant), PASS (compliant), MANUAL (requires manual verification)" + ) + + @classmethod + def from_api_response(cls, data: dict) -> "ComplianceRequirement": + """Transform JSON:API compliance requirement response to simplified format.""" + attributes = data.get("attributes", {}) + + return cls( + id=attributes.get("id", data.get("id", "")), + description=attributes.get("description", ""), + status=attributes.get("status", "MANUAL"), + ) + + +class ComplianceFrameworksListResponse(BaseModel): + """Response for compliance frameworks list with aggregated statistics.""" + + model_config = ConfigDict(frozen=True) + + frameworks: list[ComplianceFrameworkSummary] = Field( + description="List of compliance frameworks with their status" + ) + total_count: int = Field(description="Total number of frameworks returned") + + @classmethod + def from_api_response(cls, response: dict) -> "ComplianceFrameworksListResponse": + """Transform JSON:API response to simplified format.""" + data = response.get("data", []) + + frameworks = [ + ComplianceFrameworkSummary.from_api_response(item) for item in data + ] + + return cls( + frameworks=frameworks, + total_count=len(frameworks), + ) + + +class ComplianceRequirementsListResponse(BaseModel): + """Response for compliance requirements list queries.""" + + model_config = ConfigDict(frozen=True) + + requirements: list[ComplianceRequirement] = Field( + description="List of requirements with their status" + ) + total_count: int = Field(description="Total number of requirements") + passed_count: int = Field(description="Number of requirements with PASS status") + failed_count: int = Field(description="Number of requirements with FAIL status") + manual_count: int = Field(description="Number of requirements with MANUAL status") + + @classmethod + def from_api_response(cls, response: dict) -> "ComplianceRequirementsListResponse": + """Transform JSON:API response to simplified format.""" + data = response.get("data", []) + + requirements = [ComplianceRequirement.from_api_response(item) for item in data] + + # Calculate counts + passed = sum(1 for r in requirements if r.status == "PASS") + failed = sum(1 for r in requirements if r.status == "FAIL") + manual = sum(1 for r in requirements if r.status == "MANUAL") + + return cls( + requirements=requirements, + total_count=len(requirements), + passed_count=passed, + failed_count=failed, + manual_count=manual, + ) diff --git a/mcp_server/prowler_mcp_server/prowler_app/tools/compliance.py b/mcp_server/prowler_mcp_server/prowler_app/tools/compliance.py new file mode 100644 index 0000000000..81f16a83e9 --- /dev/null +++ b/mcp_server/prowler_mcp_server/prowler_app/tools/compliance.py @@ -0,0 +1,409 @@ +"""Compliance framework tools for Prowler App MCP Server. + +This module provides tools for viewing compliance status and requirement details +across all cloud providers. +""" + +from typing import Any + +from prowler_mcp_server.prowler_app.models.compliance import ( + ComplianceFrameworksListResponse, + ComplianceRequirementAttributesListResponse, + ComplianceRequirementsListResponse, +) +from prowler_mcp_server.prowler_app.tools.base import BaseTool +from pydantic import Field + + +class ComplianceTools(BaseTool): + """Tools for compliance framework operations. + + Provides tools for: + - get_compliance_overview: Get high-level compliance status across all frameworks + - get_compliance_framework_state_details: Get detailed requirement-level breakdown for a specific framework + """ + + async def _get_latest_scan_id_for_provider(self, provider_id: str) -> str: + """Get the latest completed scan_id for a given provider. + + Args: + provider_id: Prowler's internal UUID for the provider + + Returns: + The scan_id of the latest completed scan for the provider. + + Raises: + ValueError: If no completed scans are found for the provider. + """ + scan_params = { + "filter[provider]": provider_id, + "filter[state]": "completed", + "sort": "-inserted_at", + "page[size]": 1, + "page[number]": 1, + } + clean_scan_params = self.api_client.build_filter_params(scan_params) + scans_response = await self.api_client.get("/scans", params=clean_scan_params) + + scans_data = scans_response.get("data", []) + if not scans_data: + raise ValueError( + f"No completed scans found for provider {provider_id}. " + "Run a scan first using prowler_app_trigger_scan." + ) + + scan_id = scans_data[0]["id"] + return scan_id + + async def get_compliance_overview( + self, + scan_id: str | None = Field( + default=None, + description="UUID of a specific scan to get compliance data for. Required if provider_id is not specified. Use `prowler_app_list_scans` to find scan IDs.", + ), + provider_id: str | None = Field( + default=None, + description="Prowler's internal UUID (v4) for a specific provider. If provided without scan_id, the tool will automatically find the latest completed scan for this provider. Use `prowler_app_search_providers` tool to find provider IDs.", + ), + ) -> dict[str, Any]: + """Get high-level compliance overview across all frameworks for a specific scan. + + This tool provides a HIGH-LEVEL OVERVIEW of compliance status across all frameworks. + Use this when you need to understand overall compliance posture before drilling into + specific framework details. + + You have two options to specify the scan context: + 1. Provide a specific scan_id to get compliance data for that scan. + 2. Provide a provider_id to get compliance data from the latest completed scan for that provider. + + The markdown report includes: + + 1. Summary Statistics: + - Total number of compliance frameworks evaluated + - Overall compliance metrics across all frameworks + + 2. Per-Framework Breakdown: + - Framework name, version, and compliance ID + - Requirements passed/failed/manual counts + - Pass percentage for quick assessment + + Workflow: + 1. Use this tool to get an overview of all compliance frameworks + 2. Use prowler_app_get_compliance_framework_state_details with a specific compliance_id to see which requirements failed + """ + if not scan_id and not provider_id: + return { + "error": "Either scan_id or provider_id must be provided. Use prowler_app_search_providers to find provider IDs or prowler_app_list_scans to find scan IDs." + } + elif scan_id and provider_id: + return { + "error": "Provide either scan_id or provider_id, not both. To get compliance data for a specific scan, use scan_id. To get data for the latest scan of a provider, use provider_id." + } + elif not scan_id and provider_id: + try: + scan_id = await self._get_latest_scan_id_for_provider(provider_id) + except ValueError as e: + return {"error": str(e)} + + params: dict[str, Any] = {"filter[scan_id]": scan_id} + + clean_params = self.api_client.build_filter_params(params) + + # Get API response + api_response = await self.api_client.get( + "/compliance-overviews", params=clean_params + ) + frameworks_response = ComplianceFrameworksListResponse.from_api_response( + api_response + ) + + # Build markdown report + frameworks = frameworks_response.frameworks + total_frameworks = frameworks_response.total_count + + if total_frameworks == 0: + return {"report": "# Compliance Overview\n\nNo compliance frameworks found"} + + # Calculate aggregate statistics + total_requirements = sum(f.total_requirements for f in frameworks) + total_passed = sum(f.requirements_passed for f in frameworks) + total_failed = sum(f.requirements_failed for f in frameworks) + total_manual = sum(f.requirements_manual for f in frameworks) + overall_pass_pct = ( + round((total_passed / total_requirements) * 100, 1) + if total_requirements > 0 + else 0 + ) + + # Build report + report_lines = [ + "# Compliance Overview", + "", + "## Summary Statistics", + f"- **Frameworks Evaluated**: {total_frameworks}", + f"- **Total Requirements**: {total_requirements:,}", + f"- **Passed**: {total_passed:,} ({overall_pass_pct}%)", + f"- **Failed**: {total_failed:,}", + f"- **Manual Review**: {total_manual:,}", + "", + "## Framework Breakdown", + "", + ] + + # Sort frameworks by fail count (most failures first) + sorted_frameworks = sorted( + frameworks, key=lambda f: f.requirements_failed, reverse=True + ) + + for fw in sorted_frameworks: + status_indicator = "PASS" if fw.requirements_failed == 0 else "FAIL" + + report_lines.append(f"### {fw.framework} {fw.version}") + report_lines.append(f"- **Compliance ID**: `{fw.compliance_id}`") + report_lines.append(f"- **Status**: {status_indicator}") + report_lines.append( + f"- **Requirements**: {fw.requirements_passed}/{fw.total_requirements} passed ({fw.pass_percentage}%)" + ) + if fw.requirements_failed > 0: + report_lines.append(f"- **Failed**: {fw.requirements_failed}") + if fw.requirements_manual > 0: + report_lines.append(f"- **Manual Review**: {fw.requirements_manual}") + report_lines.append("") + + return {"report": "\n".join(report_lines)} + + async def _get_requirement_check_ids_mapping( + self, compliance_id: str + ) -> dict[str, list[str]]: + """Get mapping of requirement IDs to their associated check IDs. + + Args: + compliance_id: The compliance framework ID. + + Returns: + Dictionary mapping requirement ID to list of check IDs. + """ + params: dict[str, Any] = { + "filter[compliance_id]": compliance_id, + "fields[compliance-requirements-attributes]": "id,attributes", + } + + clean_params = self.api_client.build_filter_params(params) + + api_response = await self.api_client.get( + "/compliance-overviews/attributes", params=clean_params + ) + attributes_response = ( + ComplianceRequirementAttributesListResponse.from_api_response(api_response) + ) + + # Build mapping: requirement_id -> [check_ids] + return {req.id: req.check_ids for req in attributes_response.requirements} + + async def _get_failed_finding_ids_for_checks( + self, + check_ids: list[str], + scan_id: str, + ) -> list[str]: + """Get all failed finding IDs for a list of check IDs. + + Args: + check_ids: List of Prowler check IDs. + scan_id: The scan ID to filter findings. + + Returns: + List of all finding IDs with FAIL status. + """ + if not check_ids: + return [] + + all_finding_ids: list[str] = [] + page_number = 1 + page_size = 100 + + while True: + # Query findings endpoint with check_id filter and FAIL status + params: dict[str, Any] = { + "filter[scan]": scan_id, + "filter[check_id__in]": ",".join(check_ids), + "filter[status]": "FAIL", + "fields[findings]": "uid", + "page[size]": page_size, + "page[number]": page_number, + } + + clean_params = self.api_client.build_filter_params(params) + + api_response = await self.api_client.get("/findings", params=clean_params) + + findings = api_response.get("data", []) + if not findings: + break + + all_finding_ids.extend([f["id"] for f in findings]) + + # Check if we've reached the last page + if len(findings) < page_size: + break + + page_number += 1 + + return all_finding_ids + + async def get_compliance_framework_state_details( + self, + compliance_id: str = Field( + description="Compliance framework ID to get details for (e.g., 'cis_1.5_aws', 'pci_dss_v4.0_aws'). You can get compliance IDs from prowler_app_get_compliance_overview or consulting Prowler Hub/Prowler Documentation that you can also find in form of tools in this MCP Server", + ), + scan_id: str | None = Field( + default=None, + description="UUID of a specific scan to get compliance data for. Required if provider_id is not specified.", + ), + provider_id: str | None = Field( + default=None, + description="Prowler's internal UUID (v4) for a specific provider. If provided without scan_id, the tool will automatically find the latest completed scan for this provider. Use `prowler_app_search_providers` tool to find provider IDs.", + ), + ) -> dict[str, Any]: + """Get detailed requirement-level breakdown for a specific compliance framework. + + IMPORTANT: This tool returns DETAILED requirement information for a single compliance framework, + focusing on FAILED requirements and their associated FAILED finding IDs. + Use this after prowler_app_get_compliance_overview to drill down into specific frameworks. + + The markdown report includes: + + 1. Framework Summary: + - Compliance ID and scan ID used + - Overall pass/fail/manual counts + + 2. Failed Requirements Breakdown: + - Each failed requirement's ID and description + - Associated failed finding IDs for each failed requirement + - Use prowler_app_get_finding_details with these finding IDs for more details and remediation guidance + + Default behavior: + - Requires either scan_id OR provider_id + - With provider_id (no scan_id): Automatically finds the latest completed scan for that provider + - With scan_id: Uses that specific scan's compliance data + - Only shows failed requirements with their associated failed finding IDs + + Workflow: + 1. Use prowler_app_get_compliance_overview to identify frameworks with failures + 2. Use this tool with the compliance_id to see failed requirements and their finding IDs + 3. Use prowler_app_get_finding_details with the finding IDs to get remediation guidance + """ + # Validate that either scan_id or provider_id is provided + if not scan_id and not provider_id: + return { + "error": "Either scan_id or provider_id must be provided. Use prowler_app_search_providers to find provider IDs or prowler_app_list_scans to find scan IDs." + } + + # Resolve provider_id to latest scan_id if needed + resolved_scan_id = scan_id + if not scan_id and provider_id: + try: + resolved_scan_id = await self._get_latest_scan_id_for_provider( + provider_id + ) + except ValueError as e: + return {"error": str(e)} + + # Build params for requirements endpoint + params: dict[str, Any] = { + "filter[scan_id]": resolved_scan_id, + "filter[compliance_id]": compliance_id, + } + + params["fields[compliance-requirements-details]"] = "id,description,status" + + clean_params = self.api_client.build_filter_params(params) + + # Get API response + api_response = await self.api_client.get( + "/compliance-overviews/requirements", params=clean_params + ) + requirements_response = ComplianceRequirementsListResponse.from_api_response( + api_response + ) + + requirements = requirements_response.requirements + + if not requirements: + return { + "report": f"# Compliance Framework Details\n\n**Compliance ID**: `{compliance_id}`\n\nNo requirements found for this compliance framework and scan combination." + } + + # Get failed requirements + failed_reqs = [r for r in requirements if r.status == "FAIL"] + + # Get requirement -> check_ids mapping from attributes endpoint + requirement_check_mapping: dict[str, list[str]] = {} + if failed_reqs: + requirement_check_mapping = await self._get_requirement_check_ids_mapping( + compliance_id + ) + + # For each failed requirement, get the failed finding IDs + failed_req_findings: dict[str, list[str]] = {} + for req in failed_reqs: + check_ids = requirement_check_mapping.get(req.id, []) + if check_ids: + finding_ids = await self._get_failed_finding_ids_for_checks( + check_ids, resolved_scan_id + ) + failed_req_findings[req.id] = finding_ids + + # Calculate counts + total_count = len(requirements) + passed_count = sum(1 for r in requirements if r.status == "PASS") + failed_count = len(failed_reqs) + manual_count = sum(1 for r in requirements if r.status == "MANUAL") + + # Build markdown report + pass_pct = ( + round((passed_count / total_count) * 100, 1) if total_count > 0 else 0 + ) + + report_lines = [ + "# Compliance Framework Details", + "", + f"**Compliance ID**: `{compliance_id}`", + f"**Scan ID**: `{resolved_scan_id}`", + "", + "## Summary", + f"- **Total Requirements**: {total_count}", + f"- **Passed**: {passed_count} ({pass_pct}%)", + f"- **Failed**: {failed_count}", + f"- **Manual Review**: {manual_count}", + "", + ] + + # Show failed requirements with their finding IDs (most actionable) + if failed_reqs: + report_lines.append("## Failed Requirements") + report_lines.append("") + for req in failed_reqs: + report_lines.append(f"### {req.id}") + report_lines.append(f"**Description**: {req.description}") + finding_ids = failed_req_findings.get(req.id, []) + if finding_ids: + report_lines.append(f"**Failed Finding IDs** ({len(finding_ids)}):") + for fid in finding_ids: + report_lines.append(f" - `{fid}`") + else: + report_lines.append("**Failed Finding IDs**: None found") + report_lines.append("") + report_lines.append( + "*Use `prowler_app_get_finding_details` with these finding IDs to get remediation guidance.*" + ) + report_lines.append("") + + if manual_count > 0: + manual_reqs = [r for r in requirements if r.status == "MANUAL"] + report_lines.append("## Requirements Requiring Manual Review") + report_lines.append("") + for req in manual_reqs: + report_lines.append(f"- **{req.id}**: {req.description}") + report_lines.append("") + + return {"report": "\n".join(report_lines)} diff --git a/mcp_server/pyproject.toml b/mcp_server/pyproject.toml index 928676865a..c269bbe18c 100644 --- a/mcp_server/pyproject.toml +++ b/mcp_server/pyproject.toml @@ -11,7 +11,7 @@ description = "MCP server for Prowler ecosystem" name = "prowler-mcp" readme = "README.md" requires-python = ">=3.12" -version = "0.1.0" +version = "0.3.0" [project.scripts] generate-prowler-app-mcp-server = "prowler_mcp_server.prowler_app.utils.server_generator:generate_server_file" diff --git a/mcp_server/uv.lock b/mcp_server/uv.lock index 9b401fb67f..8781695d7e 100644 --- a/mcp_server/uv.lock +++ b/mcp_server/uv.lock @@ -603,7 +603,7 @@ wheels = [ [[package]] name = "prowler-mcp" -version = "0.1.0" +version = "0.3.0" source = { editable = "." } dependencies = [ { name = "fastmcp" },