feat(sdk): integrate universal compliance into CLI pipeline (#10301)

This commit is contained in:
Pedro Martín
2026-04-30 13:49:00 +02:00
committed by GitHub
parent 4608e45c8a
commit 578186aa40
17 changed files with 2634 additions and 101 deletions
+2
View File
@@ -7,12 +7,14 @@ All notable changes to the **Prowler SDK** are documented in this file.
### 🚀 Added
- `bedrock_guardrails_configured` check for AWS provider [(#10844)](https://github.com/prowler-cloud/prowler/pull/10844)
- Universal compliance pipeline integrated into the CLI: `--list-compliance` and `--list-compliance-requirements` show universal frameworks, and CSV plus OCSF outputs are generated for any framework declaring a `TableConfig` [(#10301)](https://github.com/prowler-cloud/prowler/pull/10301)
### 🔄 Changed
- `route53_dangling_ip_subdomain_takeover` now also flags `CNAME` records pointing to S3 website endpoints whose buckets are missing from the account [(#10920)](https://github.com/prowler-cloud/prowler/pull/10920)
- Azure Network Watcher flow log checks now require workspace-backed Traffic Analytics for `network_flow_log_captured_sent` and align metadata with VNet-compatible flow log guidance [(#10645)](https://github.com/prowler-cloud/prowler/pull/10645)
- Azure compliance entries for legacy Network Watcher flow log controls now use retirement-aware guidance and point new deployments to VNet flow logs
- `display_compliance_table` dispatch switched from substring `in` checks to `startswith` to prevent false matches between similarly named frameworks (e.g. `cisa` vs `cis`) [(#10301)](https://github.com/prowler-cloud/prowler/pull/10301)
### 🐞 Fixed
+41 -13
View File
@@ -45,7 +45,10 @@ from prowler.lib.check.check import (
)
from prowler.lib.check.checks_loader import load_checks_to_execute
from prowler.lib.check.compliance import update_checks_metadata_with_compliance
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.check.compliance_models import (
Compliance,
get_bulk_compliance_frameworks_universal,
)
from prowler.lib.check.custom_checks_metadata import (
parse_custom_checks_metadata_file,
update_checks_metadata,
@@ -75,7 +78,10 @@ from prowler.lib.outputs.compliance.cis.cis_oraclecloud import OracleCloudCIS
from prowler.lib.outputs.compliance.cisa_scuba.cisa_scuba_googleworkspace import (
GoogleWorkspaceCISASCuBA,
)
from prowler.lib.outputs.compliance.compliance import display_compliance_table
from prowler.lib.outputs.compliance.compliance import (
display_compliance_table,
process_universal_compliance_frameworks,
)
from prowler.lib.outputs.compliance.csa.csa_alibabacloud import AlibabaCloudCSA
from prowler.lib.outputs.compliance.csa.csa_aws import AWSCSA
from prowler.lib.outputs.compliance.csa.csa_azure import AzureCSA
@@ -235,6 +241,8 @@ def prowler():
# Load compliance frameworks
logger.debug("Loading compliance frameworks from .json files")
universal_frameworks = {}
# Skip compliance frameworks for external-tool providers
if provider not in EXTERNAL_TOOL_PROVIDERS:
bulk_compliance_frameworks = Compliance.get_bulk(provider)
@@ -242,6 +250,8 @@ def prowler():
bulk_checks_metadata = update_checks_metadata_with_compliance(
bulk_compliance_frameworks, bulk_checks_metadata
)
# Load universal compliance frameworks for new rendering pipeline
universal_frameworks = get_bulk_compliance_frameworks_universal(provider)
# Update checks metadata if the --custom-checks-metadata-file is present
custom_checks_metadata = None
@@ -254,12 +264,12 @@ def prowler():
)
if args.list_compliance:
print_compliance_frameworks(bulk_compliance_frameworks)
all_frameworks = {**bulk_compliance_frameworks, **universal_frameworks}
print_compliance_frameworks(all_frameworks)
sys.exit()
if args.list_compliance_requirements:
print_compliance_requirements(
bulk_compliance_frameworks, args.list_compliance_requirements
)
all_frameworks = {**bulk_compliance_frameworks, **universal_frameworks}
print_compliance_requirements(all_frameworks, args.list_compliance_requirements)
sys.exit()
# Load checks to execute
@@ -276,6 +286,7 @@ def prowler():
provider=provider,
list_checks=getattr(args, "list_checks", False)
or getattr(args, "list_checks_json", False),
universal_frameworks=universal_frameworks,
)
# if --list-checks-json, dump a json file and exit
@@ -624,15 +635,29 @@ def prowler():
)
# Compliance Frameworks
# Source the framework listing from `bulk_compliance_frameworks.keys()`
# so it is by construction a subset of what the bulk loader can resolve.
# `get_available_compliance_frameworks(provider)` also discovers top-level
# multi-provider universal JSONs (e.g. `prowler/compliance/csa_ccm_4.0.json`)
# which `Compliance.get_bulk(provider)` does not load, and which the legacy
# output handlers below cannot consume — using it as the source produced
# Source the framework listing from the union of `bulk_compliance_frameworks`
# and `universal_frameworks` so universal-only frameworks (e.g.
# `prowler/compliance/csa_ccm_4.0.json`) — which `Compliance.get_bulk(provider)`
# does not load — still reach `process_universal_compliance_frameworks` below.
# The provider-specific block subtracts the names handled by the universal
# processor so the legacy per-provider handlers only see frameworks that the
# bulk loader actually resolved.
input_compliance_frameworks = set(output_options.output_modes).intersection(
bulk_compliance_frameworks.keys()
set(bulk_compliance_frameworks.keys()) | set(universal_frameworks.keys())
)
# ── Universal compliance frameworks (provider-agnostic) ──
universal_processed = process_universal_compliance_frameworks(
input_compliance_frameworks=input_compliance_frameworks,
universal_frameworks=universal_frameworks,
finding_outputs=finding_outputs,
output_directory=output_options.output_directory,
output_filename=output_options.output_filename,
provider=provider,
generated_outputs=generated_outputs,
)
input_compliance_frameworks -= universal_processed
if provider == "aws":
for compliance_name in input_compliance_frameworks:
if compliance_name.startswith("cis_"):
@@ -1402,6 +1427,9 @@ def prowler():
output_options.output_filename,
output_options.output_directory,
compliance_overview,
universal_frameworks=universal_frameworks,
provider=provider,
output_formats=args.output_formats,
)
if compliance_overview:
print(
+5 -3
View File
@@ -87,8 +87,8 @@ def get_available_compliance_frameworks(provider=None):
providers = [p.value for p in Provider]
if provider:
providers = [provider]
for provider in providers:
compliance_dir = f"{actual_directory}/../compliance/{provider}"
for current_provider in providers:
compliance_dir = f"{actual_directory}/../compliance/{current_provider}"
if not os.path.isdir(compliance_dir):
continue
with os.scandir(compliance_dir) as files:
@@ -97,7 +97,9 @@ def get_available_compliance_frameworks(provider=None):
available_compliance_frameworks.append(
file.name.removesuffix(".json")
)
# Also scan top-level compliance/ for multi-provider JSONs
# Also scan top-level compliance/ for multi-provider (universal) JSONs.
# When a specific provider was requested, only include the framework if it
# declares support for that provider; otherwise include all universal frameworks.
compliance_root = f"{actual_directory}/../compliance"
if os.path.isdir(compliance_root):
with os.scandir(compliance_root) as files:
+30 -7
View File
@@ -299,12 +299,22 @@ def print_compliance_frameworks(
def print_compliance_requirements(
bulk_compliance_frameworks: dict, compliance_frameworks: list
):
from prowler.lib.check.compliance_models import ComplianceFramework
for compliance_framework in compliance_frameworks:
for key in bulk_compliance_frameworks.keys():
framework = bulk_compliance_frameworks[key].Framework
provider = bulk_compliance_frameworks[key].Provider
version = bulk_compliance_frameworks[key].Version
requirements = bulk_compliance_frameworks[key].Requirements
entry = bulk_compliance_frameworks[key]
is_universal = isinstance(entry, ComplianceFramework)
if is_universal:
framework = entry.framework
provider = entry.provider or "Multi-provider"
version = entry.version
requirements = entry.requirements
else:
framework = entry.Framework
provider = entry.Provider or "Multi-provider"
version = entry.Version
requirements = entry.Requirements
# We can list the compliance requirements for a given framework using the
# bulk_compliance_frameworks keys since they are the compliance specification file name
if compliance_framework == key:
@@ -313,10 +323,23 @@ def print_compliance_requirements(
)
for requirement in requirements:
checks = ""
for check in requirement.Checks:
checks += f" {Fore.YELLOW}\t\t{check}\n{Style.RESET_ALL}"
if is_universal:
req_checks = requirement.checks
req_id = requirement.id
req_description = requirement.description
else:
req_checks = requirement.Checks
req_id = requirement.Id
req_description = requirement.Description
if isinstance(req_checks, dict):
for prov, check_list in req_checks.items():
for check in check_list:
checks += f" {Fore.YELLOW}\t\t[{prov}] {check}\n{Style.RESET_ALL}"
else:
for check in req_checks:
checks += f" {Fore.YELLOW}\t\t{check}\n{Style.RESET_ALL}"
print(
f"Requirement Id: {Fore.MAGENTA}{requirement.Id}{Style.RESET_ALL}\n\t- Description: {requirement.Description}\n\t- Checks:\n{checks}"
f"Requirement Id: {Fore.MAGENTA}{req_id}{Style.RESET_ALL}\n\t- Description: {req_description}\n\t- Checks:\n{checks}"
)
+15 -5
View File
@@ -22,6 +22,7 @@ def load_checks_to_execute(
categories: set = None,
resource_groups: set = None,
list_checks: bool = False,
universal_frameworks: dict = None,
) -> set:
"""Generate the list of checks to execute based on the cloud provider and the input arguments given"""
try:
@@ -155,12 +156,21 @@ def load_checks_to_execute(
if not bulk_compliance_frameworks:
bulk_compliance_frameworks = Compliance.get_bulk(provider=provider)
for compliance_framework in compliance_frameworks:
checks_to_execute.update(
CheckMetadata.list(
bulk_compliance_frameworks=bulk_compliance_frameworks,
compliance_framework=compliance_framework,
# Try universal frameworks first (snake_case dict-keyed checks)
if (
universal_frameworks
and compliance_framework in universal_frameworks
):
fw = universal_frameworks[compliance_framework]
for req in fw.requirements:
checks_to_execute.update(req.checks.get(provider.lower(), []))
elif compliance_framework in bulk_compliance_frameworks:
checks_to_execute.update(
CheckMetadata.list(
bulk_compliance_frameworks=bulk_compliance_frameworks,
compliance_framework=compliance_framework,
)
)
)
# Handle if there are categories passed using --categories
elif categories:
+129 -62
View File
@@ -1,10 +1,12 @@
import sys
from prowler.lib.check.models import Check_Report
from prowler.lib.logger import logger
from prowler.lib.outputs.compliance.c5.c5 import get_c5_table
from prowler.lib.outputs.compliance.ccc.ccc import get_ccc_table
from prowler.lib.outputs.compliance.cis.cis import get_cis_table
from prowler.lib.outputs.compliance.compliance_check import ( # noqa: F401 - re-export for backward compatibility
get_check_compliance,
)
from prowler.lib.outputs.compliance.csa.csa import get_csa_table
from prowler.lib.outputs.compliance.ens.ens import get_ens_table
from prowler.lib.outputs.compliance.generic.generic_table import (
@@ -17,6 +19,94 @@ from prowler.lib.outputs.compliance.mitre_attack.mitre_attack import (
from prowler.lib.outputs.compliance.prowler_threatscore.prowler_threatscore import (
get_prowler_threatscore_table,
)
from prowler.lib.outputs.compliance.universal.universal_table import get_universal_table
def process_universal_compliance_frameworks(
    input_compliance_frameworks: set,
    universal_frameworks: dict,
    finding_outputs: list,
    output_directory: str,
    output_filename: str,
    provider: str,
    generated_outputs: dict,
) -> set:
    """Emit CSV and OCSF outputs for every renderable universal framework.

    A framework from *input_compliance_frameworks* is renderable when it is
    present in *universal_frameworks* and declares ``outputs.table_config``.
    For each such framework one CSV writer (``UniversalComplianceOutput``)
    and one OCSF JSON writer (``OCSFComplianceOutput``) are created and
    flushed; the OCSF file is produced regardless of ``--output-formats``.

    Writers already registered in ``generated_outputs["compliance"]`` (matched
    by ``file_path``) are not recreated — a repeated call for the same
    framework leaves the previously created file untouched, so there is at
    most one writer per output file for the whole execution.

    Returns:
        set: the framework names handled here, so the caller can exclude
        them from the legacy per-provider output loop.
    """
    from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
        OCSFComplianceOutput,
    )
    from prowler.lib.outputs.compliance.universal.universal_output import (
        UniversalComplianceOutput,
    )

    # Index the universal writers created so far by target file path.
    known_writers = {}
    for existing in generated_outputs.get("compliance", []):
        if isinstance(existing, (UniversalComplianceOutput, OCSFComplianceOutput)):
            known_writers[getattr(existing, "file_path", None)] = existing

    handled = set()
    for name in input_compliance_frameworks:
        framework = universal_frameworks.get(name)
        renderable = (
            framework is not None
            and framework.outputs
            and framework.outputs.table_config
        )
        if not renderable:
            continue

        base_path = f"{output_directory}/compliance/{output_filename}_{name}"

        # CSV output — created once per framework per execution.
        csv_target = f"{base_path}.csv"
        if csv_target not in known_writers:
            csv_writer = UniversalComplianceOutput(
                findings=finding_outputs,
                framework=framework,
                file_path=csv_target,
                provider=provider,
            )
            generated_outputs["compliance"].append(csv_writer)
            known_writers[csv_target] = csv_writer
            csv_writer.batch_write_data_to_file()

        # OCSF output — always generated for universal frameworks.
        ocsf_target = f"{base_path}.ocsf.json"
        if ocsf_target not in known_writers:
            ocsf_writer = OCSFComplianceOutput(
                findings=finding_outputs,
                framework=framework,
                file_path=ocsf_target,
                provider=provider,
            )
            generated_outputs["compliance"].append(ocsf_writer)
            known_writers[ocsf_target] = ocsf_writer
            ocsf_writer.batch_write_data_to_file()

        handled.add(name)
    return handled
def display_compliance_table(
@@ -26,6 +116,9 @@ def display_compliance_table(
output_filename: str,
output_directory: str,
compliance_overview: bool,
universal_frameworks: dict = None,
provider: str = None,
output_formats: list = None,
) -> None:
"""
display_compliance_table generates the compliance table for the given compliance framework.
@@ -37,6 +130,9 @@ def display_compliance_table(
output_filename (str): The output filename
output_directory (str): The output directory
compliance_overview (bool): The compliance
universal_frameworks (dict): Optional universal ComplianceFramework objects
provider (str): The current provider (e.g. "aws") for multi-provider filtering
output_formats (list): The output formats to generate
Returns:
None
@@ -45,16 +141,24 @@ def display_compliance_table(
findings = [f for f in findings if f.check_metadata.CheckID in bulk_checks_metadata]
try:
if "ens_" in compliance_framework:
get_ens_table(
findings,
bulk_checks_metadata,
compliance_framework,
output_filename,
output_directory,
compliance_overview,
)
elif "cis_" in compliance_framework:
# Universal path: if the framework has TableConfig, use the universal renderer
if universal_frameworks and compliance_framework in universal_frameworks:
fw = universal_frameworks[compliance_framework]
if fw.outputs and fw.outputs.table_config:
get_universal_table(
findings,
bulk_checks_metadata,
compliance_framework,
output_filename,
output_directory,
compliance_overview,
framework=fw,
provider=provider,
output_formats=output_formats,
)
return
if compliance_framework.startswith("cis_"):
get_cis_table(
findings,
bulk_checks_metadata,
@@ -63,7 +167,16 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "mitre_attack" in compliance_framework:
elif compliance_framework.startswith("ens_"):
get_ens_table(
findings,
bulk_checks_metadata,
compliance_framework,
output_filename,
output_directory,
compliance_overview,
)
elif compliance_framework.startswith("mitre_attack"):
get_mitre_attack_table(
findings,
bulk_checks_metadata,
@@ -72,7 +185,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "kisa_isms_" in compliance_framework:
elif compliance_framework.startswith("kisa"):
get_kisa_ismsp_table(
findings,
bulk_checks_metadata,
@@ -81,7 +194,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "threatscore_" in compliance_framework:
elif compliance_framework.startswith("prowler_threatscore_"):
get_prowler_threatscore_table(
findings,
bulk_checks_metadata,
@@ -90,7 +203,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "csa_ccm_" in compliance_framework:
elif compliance_framework.startswith("csa_ccm_"):
get_csa_table(
findings,
bulk_checks_metadata,
@@ -99,7 +212,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "c5_" in compliance_framework:
elif compliance_framework.startswith("c5_"):
get_c5_table(
findings,
bulk_checks_metadata,
@@ -131,49 +244,3 @@ def display_compliance_table(
f"{error.__class__.__name__}:{error.__traceback__.tb_lineno} -- {error}"
)
sys.exit(1)
# TODO: this should be in the Check class
def get_check_compliance(
finding: Check_Report, provider_type: str, bulk_checks_metadata: dict
) -> dict:
"""get_check_compliance returns a map with the compliance framework as key and the requirements where the finding's check is present.
Example:
{
"CIS-1.4": ["2.1.3"],
"CIS-1.5": ["2.1.3"],
}
Args:
finding (Any): The Check_Report finding
provider_type (str): The provider type
bulk_checks_metadata (dict): The bulk checks metadata
Returns:
dict: The compliance framework as key and the requirements where the finding's check is present.
"""
try:
check_compliance = {}
# We have to retrieve all the check's compliance requirements
if finding.check_metadata.CheckID in bulk_checks_metadata:
for compliance in bulk_checks_metadata[
finding.check_metadata.CheckID
].Compliance:
compliance_fw = compliance.Framework
if compliance.Version:
compliance_fw = f"{compliance_fw}-{compliance.Version}"
# compliance.Provider == "Azure" or "Kubernetes"
# provider_type == "azure" or "kubernetes"
if compliance.Provider.upper() == provider_type.upper():
if compliance_fw not in check_compliance:
check_compliance[compliance_fw] = []
for requirement in compliance.Requirements:
check_compliance[compliance_fw].append(requirement.Id)
return check_compliance
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
return {}
@@ -0,0 +1,48 @@
from prowler.lib.check.models import Check_Report
from prowler.lib.logger import logger
# TODO: this should be in the Check class
def get_check_compliance(
    finding: Check_Report, provider_type: str, bulk_checks_metadata: dict
) -> dict:
    """Map each compliance framework to the requirement ids covering this finding's check.

    Example:
        {
            "CIS-1.4": ["2.1.3"],
            "CIS-1.5": ["2.1.3"],
        }

    Args:
        finding (Any): The Check_Report finding
        provider_type (str): The provider type
        bulk_checks_metadata (dict): The bulk checks metadata

    Returns:
        dict: Framework name (suffixed with "-<version>" when the framework
            declares a version) mapped to the list of requirement ids where
            the finding's check is present. Empty dict on error or when the
            check has no metadata entry.
    """
    try:
        check_id = finding.check_metadata.CheckID
        if check_id not in bulk_checks_metadata:
            return {}
        check_compliance = {}
        for compliance in bulk_checks_metadata[check_id].Compliance:
            # Case-insensitive provider match:
            # compliance.Provider == "Azure" or "Kubernetes"
            # provider_type == "azure" or "kubernetes"
            if compliance.Provider.upper() != provider_type.upper():
                continue
            compliance_fw = compliance.Framework
            if compliance.Version:
                compliance_fw = f"{compliance_fw}-{compliance.Version}"
            requirement_ids = check_compliance.setdefault(compliance_fw, [])
            for requirement in compliance.Requirements:
                requirement_ids.append(requirement.Id)
        return check_compliance
    except Exception as error:
        logger.error(
            f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
        )
        return {}
@@ -1,6 +1,7 @@
import json
import os
from datetime import datetime
from typing import List
from typing import TYPE_CHECKING, List
from py_ocsf_models.events.base_event import SeverityID
from py_ocsf_models.events.base_event import StatusID as EventStatusID
@@ -20,11 +21,12 @@ from py_ocsf_models.objects.resource_details import ResourceDetails
from prowler.config.config import prowler_version
from prowler.lib.check.compliance_models import ComplianceFramework
from prowler.lib.logger import logger
from prowler.lib.outputs.finding import Finding
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.lib.outputs.utils import unroll_dict_to_list
from prowler.lib.utils.utils import open_file
if TYPE_CHECKING:
from prowler.lib.outputs.finding import Finding
PROWLER_TO_COMPLIANCE_STATUS = {
"PASS": ComplianceStatusID.Pass,
"FAIL": ComplianceStatusID.Fail,
@@ -32,6 +34,40 @@ PROWLER_TO_COMPLIANCE_STATUS = {
}
def _sanitize_resource_data(resource_details, resource_metadata) -> dict:
"""Ensure resource data is JSON-serializable.
Service resource_metadata may carry non-serializable objects (e.g. raw
Pydantic models or service classes such as ``Trail`` / ``LifecyclePolicy``).
Convert them to plain dicts and roundtrip through JSON so the resulting
ComplianceFinding can be serialized without errors.
"""
def _make_serializable(obj):
if hasattr(obj, "model_dump") and callable(obj.model_dump):
return _make_serializable(obj.model_dump())
if hasattr(obj, "dict") and callable(obj.dict):
return _make_serializable(obj.dict())
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_make_serializable(v) for v in obj]
return obj
try:
converted = _make_serializable(resource_metadata)
sanitized_metadata = json.loads(json.dumps(converted, default=str))
except (TypeError, ValueError, RecursionError) as error:
logger.warning(
f"Failed to serialize resource metadata, defaulting to empty: {error}"
)
sanitized_metadata = {}
return {
"details": resource_details,
"metadata": sanitized_metadata,
}
def _to_snake_case(name: str) -> str:
"""Convert a PascalCase or camelCase string to snake_case."""
import re
@@ -108,7 +144,7 @@ class OCSFComplianceOutput:
def _transform(
self,
findings: List[Finding],
findings: List["Finding"],
framework: ComplianceFramework,
compliance_name: str,
) -> None:
@@ -177,7 +213,7 @@ class OCSFComplianceOutput:
def _build_compliance_finding(
self,
finding: Finding,
finding: "Finding",
framework: ComplianceFramework,
requirement,
compliance_name: str,
@@ -195,7 +231,9 @@ class OCSFComplianceOutput:
finding.metadata.Severity.capitalize(),
SeverityID.Unknown,
)
event_status = OCSF.get_finding_status_id(finding.muted)
event_status = (
EventStatusID.Suppressed if finding.muted else EventStatusID.New
)
time_value = (
int(finding.timestamp.timestamp())
@@ -268,10 +306,10 @@ class OCSFComplianceOutput:
if finding.provider == "kubernetes"
else None
),
data={
"details": finding.resource_details,
"metadata": finding.resource_metadata,
},
data=_sanitize_resource_data(
finding.resource_details,
finding.resource_metadata,
),
)
],
severity_id=finding_severity.value,
@@ -0,0 +1,294 @@
from csv import DictWriter
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from pydantic.v1 import create_model
from prowler.config.config import timestamp
from prowler.lib.check.compliance_models import ComplianceFramework
from prowler.lib.logger import logger
from prowler.lib.utils.utils import open_file
if TYPE_CHECKING:
from prowler.lib.outputs.finding import Finding
# Per-provider CSV header customization. Each value is a 4-tuple unpacked by
# UniversalComplianceOutput._build_row_model as:
#   (account column header, finding attribute backing it,
#    location column header, finding attribute backing it)
PROVIDER_HEADER_MAP = {
    "aws": ("AccountId", "account_uid", "Region", "region"),
    "azure": ("SubscriptionId", "account_uid", "Location", "region"),
    "gcp": ("ProjectId", "account_uid", "Location", "region"),
    "kubernetes": ("Context", "account_name", "Namespace", "region"),
    "m365": ("TenantId", "account_uid", "Location", "region"),
    "github": ("Account_Name", "account_name", "Account_Id", "account_uid"),
    "oraclecloud": ("TenancyId", "account_uid", "Region", "region"),
    "alibabacloud": ("AccountId", "account_uid", "Region", "region"),
    "nhn": ("AccountId", "account_uid", "Region", "region"),
}

# Fallback headers used when the provider is unknown or not supplied.
_DEFAULT_HEADERS = ("AccountId", "account_uid", "Region", "region")
class UniversalComplianceOutput:
    """Universal compliance CSV output driven by ComplianceFramework metadata.

    Dynamically builds a Pydantic row model from attributes_metadata so that
    CSV columns match the framework's declared attribute fields.
    """

    def __init__(
        self,
        findings: list,
        framework: ComplianceFramework,
        file_path: str = None,
        from_cli: bool = True,
        provider: str = None,
    ) -> None:
        # Rows (instances of the dynamic row model) filled in by _transform.
        self._data = []
        self._file_descriptor = None
        self.file_path = file_path
        # When True, batch_write_data_to_file closes the descriptor after writing.
        self._from_cli = from_cli
        # Provider name used to select CSV headers and filter dict-keyed checks.
        self._provider = provider
        self.close_file = False
        if file_path:
            path_obj = Path(file_path)
            self._file_extension = path_obj.suffix if path_obj.suffix else ""
        if findings:
            # Build the dynamic row model first: it also sets the
            # _acct_*/_loc_* header attributes that _build_row relies on.
            self._row_model = self._build_row_model(framework)
            compliance_name = (
                framework.framework + "-" + framework.version
                if framework.version
                else framework.framework
            )
            self._transform(findings, framework, compliance_name)
        if not self._file_descriptor and file_path:
            self._create_file_descriptor(file_path)

    @property
    def data(self):
        # Read-only access to the transformed rows.
        return self._data

    def _build_row_model(self, framework: ComplianceFramework):
        """Build a dynamic Pydantic model from attributes_metadata.

        Side effect: stores the provider-specific account/location header and
        attribute names on self (used later by _build_row).
        """
        acct_header, acct_field, loc_header, loc_field = PROVIDER_HEADER_MAP.get(
            (self._provider or "").lower(), _DEFAULT_HEADERS
        )
        self._acct_header = acct_header
        self._acct_field = acct_field
        self._loc_header = loc_header
        self._loc_field = loc_field
        # Base fields present in every compliance CSV
        fields = {
            "Provider": (str, ...),
            "Description": (str, ...),
            acct_header: (str, ...),
            loc_header: (str, ...),
            "AssessmentDate": (str, ...),
            "Requirements_Id": (str, ...),
            "Requirements_Description": (str, ...),
        }
        # Dynamic attribute columns from metadata (only those flagged for CSV)
        if framework.attributes_metadata:
            for attr_meta in framework.attributes_metadata:
                if not attr_meta.output_formats.csv:
                    continue
                field_name = f"Requirements_Attributes_{attr_meta.key}"
                # Map type strings to Python types
                type_map = {
                    "str": Optional[str],
                    "int": Optional[int],
                    "float": Optional[float],
                    "bool": Optional[bool],
                    "list_str": Optional[str],  # Serialized as joined string
                    "list_dict": Optional[str],  # Serialized as string
                }
                py_type = type_map.get(attr_meta.type, Optional[str])
                fields[field_name] = (py_type, None)
        # Check if any requirement has MITRE fields
        has_mitre = any(req.tactics for req in framework.requirements if req.tactics)
        if has_mitre:
            fields["Requirements_Tactics"] = (Optional[str], None)
            fields["Requirements_SubTechniques"] = (Optional[str], None)
            fields["Requirements_Platforms"] = (Optional[str], None)
            fields["Requirements_TechniqueURL"] = (Optional[str], None)
        # Trailing fields
        fields["Status"] = (str, ...)
        fields["StatusExtended"] = (str, ...)
        fields["ResourceId"] = (str, ...)
        fields["ResourceName"] = (str, ...)
        fields["CheckId"] = (str, ...)
        fields["Muted"] = (bool, ...)
        fields["Framework"] = (str, ...)
        fields["Name"] = (str, ...)
        return create_model("UniversalComplianceRow", **fields)

    def _serialize_attr_value(self, value):
        """Serialize attribute values for CSV.

        Lists of dicts become their repr; other lists are joined with " | ";
        scalars pass through unchanged.
        """
        if isinstance(value, list):
            if value and isinstance(value[0], dict):
                return str(value)
            return " | ".join(str(v) for v in value)
        return value

    def _build_row(self, finding, framework, requirement, is_manual=False):
        """Build a single row dict for a finding + requirement combination.

        With is_manual=True, finding is a stub and the row gets the fixed
        MANUAL status/resource placeholders instead of finding attributes.
        """
        row = {
            "Provider": (
                finding.provider
                if not is_manual
                else (framework.provider or self._provider or "").lower()
            ),
            "Description": framework.description,
            self._acct_header: (
                getattr(finding, self._acct_field, "") if not is_manual else ""
            ),
            self._loc_header: (
                getattr(finding, self._loc_field, "") if not is_manual else ""
            ),
            "AssessmentDate": str(timestamp),
            "Requirements_Id": requirement.id,
            "Requirements_Description": requirement.description,
        }
        # Add dynamic attribute columns
        if framework.attributes_metadata:
            for attr_meta in framework.attributes_metadata:
                if not attr_meta.output_formats.csv:
                    continue
                field_name = f"Requirements_Attributes_{attr_meta.key}"
                raw_val = requirement.attributes.get(attr_meta.key)
                row[field_name] = (
                    self._serialize_attr_value(raw_val) if raw_val is not None else None
                )
        # MITRE fields (only populated when the requirement declares tactics)
        if requirement.tactics:
            row["Requirements_Tactics"] = (
                " | ".join(requirement.tactics) if requirement.tactics else None
            )
            row["Requirements_SubTechniques"] = (
                " | ".join(requirement.sub_techniques)
                if requirement.sub_techniques
                else None
            )
            row["Requirements_Platforms"] = (
                " | ".join(requirement.platforms) if requirement.platforms else None
            )
            row["Requirements_TechniqueURL"] = requirement.technique_url
        row["Status"] = finding.status if not is_manual else "MANUAL"
        row["StatusExtended"] = (
            finding.status_extended if not is_manual else "Manual check"
        )
        row["ResourceId"] = finding.resource_uid if not is_manual else "manual_check"
        row["ResourceName"] = finding.resource_name if not is_manual else "Manual check"
        row["CheckId"] = finding.check_id if not is_manual else "manual"
        row["Muted"] = finding.muted if not is_manual else False
        row["Framework"] = framework.framework
        row["Name"] = framework.name
        return row

    def _transform(
        self,
        findings: list["Finding"],
        framework: ComplianceFramework,
        compliance_name: str,
    ) -> None:
        """Transform findings into universal compliance CSV rows."""
        # Build check -> requirements map (filtered by provider for dict checks)
        check_req_map = {}
        for req in framework.requirements:
            checks = req.checks
            if self._provider:
                all_checks = checks.get(self._provider.lower(), [])
            else:
                # No provider filter: flatten the checks of every provider.
                all_checks = []
                for check_list in checks.values():
                    all_checks.extend(check_list)
            for check_id in all_checks:
                if check_id not in check_req_map:
                    check_req_map[check_id] = []
                check_req_map[check_id].append(req)
        # Process findings using the provider-filtered check_req_map.
        # This ensures that for multi-provider dict checks, only the checks
        # belonging to the current provider produce output rows.
        for finding in findings:
            check_id = finding.check_id
            if check_id in check_req_map:
                for req in check_req_map[check_id]:
                    row = self._build_row(finding, framework, req)
                    try:
                        self._data.append(self._row_model(**row))
                    except Exception as e:
                        # Validation failures drop the row but keep processing.
                        logger.debug(f"Skipping row for {req.id}: {e}")
        # Manual requirements (no checks or empty dict)
        for req in framework.requirements:
            checks = req.checks
            if self._provider:
                has_checks = bool(checks.get(self._provider.lower(), []))
            else:
                has_checks = any(checks.values())
            if not has_checks:
                # Use a dummy finding-like namespace for manual rows
                row = self._build_row(
                    _ManualFindingStub(), framework, req, is_manual=True
                )
                try:
                    self._data.append(self._row_model(**row))
                except Exception as e:
                    logger.debug(f"Skipping manual row for {req.id}: {e}")

    def _create_file_descriptor(self, file_path: str) -> None:
        # Open in append mode: an existing file keeps its header (see tell()==0
        # check in batch_write_data_to_file).
        try:
            self._file_descriptor = open_file(file_path, "a")
        except Exception as error:
            logger.error(
                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    def batch_write_data_to_file(self) -> None:
        """Write findings data to CSV.

        Column names are the row-model field names uppercased; ';' is the
        delimiter. The header is written only when the file is empty. The
        descriptor is closed when close_file or _from_cli is set.
        """
        try:
            if (
                getattr(self, "_file_descriptor", None)
                and not self._file_descriptor.closed
                and self._data
            ):
                csv_writer = DictWriter(
                    self._file_descriptor,
                    fieldnames=[field.upper() for field in self._data[0].dict().keys()],
                    delimiter=";",
                )
                if self._file_descriptor.tell() == 0:
                    csv_writer.writeheader()
                for row in self._data:
                    csv_writer.writerow({k.upper(): v for k, v in row.dict().items()})
                if self.close_file or self._from_cli:
                    self._file_descriptor.close()
        except Exception as error:
            logger.error(
                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )
class _ManualFindingStub:
"""Minimal stub to satisfy _build_row for manual requirements."""
provider = ""
account_uid = ""
account_name = ""
region = ""
status = "MANUAL"
status_extended = "Manual check"
resource_uid = "manual_check"
resource_name = "Manual check"
check_id = "manual"
muted = False
+1 -1
View File
@@ -15,7 +15,7 @@ from prowler.lib.check.models import (
)
from prowler.lib.logger import logger
from prowler.lib.outputs.common import Status, fill_common_finding_data
from prowler.lib.outputs.compliance.compliance import get_check_compliance
from prowler.lib.outputs.compliance.compliance_check import get_check_compliance
from prowler.lib.outputs.utils import unroll_tags
from prowler.lib.utils.utils import dict_to_lowercase, get_nested_attribute
from prowler.providers.common.provider import Provider
+27
View File
@@ -436,6 +436,33 @@ class Test_Config:
assert "csa_ccm_4.0" in aws_frameworks
assert "csa_ccm_4.0" not in kubernetes_frameworks
def test_get_available_compliance_frameworks_no_provider_includes_universals(self):
    """Universal frameworks must be listed when no provider is given.

    Guards against a variable-shadowing regression: the inner
    ``for provider in providers`` loop used to shadow the outer
    ``provider`` parameter, so a no-provider call wrongly applied
    ``framework.supports_provider(<last provider iterated>)`` and dropped
    universal frameworks such as ``csa_ccm_4.0`` — which in turn made the
    parser-level ``available_compliance_frameworks`` constant reject
    ``--compliance csa_ccm_4.0``.
    """
    frameworks = get_available_compliance_frameworks()
    assert "csa_ccm_4.0" in frameworks
def test_get_available_compliance_frameworks_does_not_mutate_provider_param(self):
    """A provider-specific call must not leak state into a later call.

    First force an iteration with an explicit provider, then verify a
    subsequent no-provider call still includes universal frameworks
    supported by ANY provider — proving the loop-variable rename keeps
    calls independent of each other.
    """
    # Force an iteration over multiple providers first.
    get_available_compliance_frameworks("kubernetes")
    # The no-provider call must not be filtered by any leaked value.
    frameworks = get_available_compliance_frameworks()
    assert "csa_ccm_4.0" in frameworks
def test_load_and_validate_config_file_aws(self):
path = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
config_test_file = f"{path}/fixtures/config.yaml"
+174
View File
@@ -675,3 +675,177 @@ class TestCheckLoader:
)
assert CLOUDTRAIL_THREAT_DETECTION_ENUMERATION_NAME not in result
assert S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME in result
def test_load_checks_to_execute_universal_framework_takes_precedence(self):
    """Universal frameworks win over the legacy bulk lookup.

    When ``--compliance <fw>`` matches an entry in ``universal_frameworks``,
    the loader must source checks from
    ``requirements[*].checks[provider]`` and NOT fall through to
    ``bulk_compliance_frameworks`` (path added by PR #10301).
    """
    from prowler.lib.check.compliance_models import (
        ComplianceFramework,
        UniversalComplianceRequirement,
    )

    checks_metadata = {
        S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
    }
    requirement = UniversalComplianceRequirement(
        id="A&A-01",
        description="Audit & Assurance",
        attributes={},
        checks={"aws": [S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME]},
    )
    framework = ComplianceFramework(
        framework="csa_ccm",
        name="CSA CCM 4.0",
        version="4.0",
        description="Cloud Controls Matrix",
        requirements=[requirement],
    )

    with patch(
        "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
        return_value=checks_metadata,
    ):
        loaded = load_checks_to_execute(
            bulk_checks_metadata=checks_metadata,
            bulk_compliance_frameworks={},  # legacy lookup intentionally empty
            compliance_frameworks=["csa_ccm_4.0"],
            provider=self.provider,
            universal_frameworks={"csa_ccm_4.0": framework},
        )

    assert loaded == {S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME}
def test_load_checks_to_execute_universal_filters_by_provider(self):
    """Only the active provider's checks are picked from a requirement.

    A universal requirement may declare checks for several providers;
    for ``provider="aws"`` the loader must only return the entries under
    the lowercased ``aws`` key.
    """
    from prowler.lib.check.compliance_models import (
        ComplianceFramework,
        UniversalComplianceRequirement,
    )

    checks_metadata = {
        S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
    }
    # One requirement mapping a different check per provider: only the
    # AWS entry must be returned for provider="aws".
    multi_provider_requirement = UniversalComplianceRequirement(
        id="A&A-02",
        description="Multi-provider req",
        attributes={},
        checks={
            "aws": [S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME],
            "azure": ["azure_only_check"],
            "gcp": ["gcp_only_check"],
        },
    )
    framework = ComplianceFramework(
        framework="csa_ccm",
        name="CSA CCM 4.0",
        version="4.0",
        description="Cloud Controls Matrix",
        requirements=[multi_provider_requirement],
    )

    with patch(
        "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
        return_value=checks_metadata,
    ):
        loaded = load_checks_to_execute(
            bulk_checks_metadata=checks_metadata,
            bulk_compliance_frameworks={},
            compliance_frameworks=["csa_ccm_4.0"],
            provider=self.provider,  # "aws"
            universal_frameworks={"csa_ccm_4.0": framework},
        )

    assert S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME in loaded
    assert "azure_only_check" not in loaded
    assert "gcp_only_check" not in loaded
def test_load_checks_to_execute_universal_no_match_falls_back_to_legacy(self):
    """Non-universal frameworks resolve via the legacy bulk lookup.

    When the requested framework is absent from ``universal_frameworks``,
    the loader must fall back to ``bulk_compliance_frameworks``.
    """
    checks_metadata = {
        S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
    }
    legacy_frameworks = {
        "soc2_aws": Compliance(
            Framework="SOC2",
            Name="SOC2",
            Provider="aws",
            Version="2.0",
            Description="x",
            Requirements=[
                Compliance_Requirement(
                    Checks=[S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME],
                    Id="",
                    Description="",
                    Attributes=[],
                )
            ],
        ),
    }

    with patch(
        "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
        return_value=checks_metadata,
    ):
        loaded = load_checks_to_execute(
            bulk_checks_metadata=checks_metadata,
            bulk_compliance_frameworks=legacy_frameworks,
            compliance_frameworks=["soc2_aws"],
            provider=self.provider,
            universal_frameworks={"some_other_universal_fw": object()},
        )

    assert loaded == {S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME}
def test_load_checks_to_execute_universal_unknown_provider_returns_empty(self):
    """No checks are loaded when a requirement covers other providers only.

    The requirement below declares checks for Azure only, so an AWS run
    must not pick up anything from it.
    """
    from prowler.lib.check.compliance_models import (
        ComplianceFramework,
        UniversalComplianceRequirement,
    )

    checks_metadata = {
        S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
    }
    azure_only_requirement = UniversalComplianceRequirement(
        id="A&A-03",
        description="Only Azure",
        attributes={},
        checks={"azure": ["azure_only_check"]},
    )
    framework = ComplianceFramework(
        framework="csa_ccm",
        name="CSA CCM 4.0",
        version="4.0",
        description="Cloud Controls Matrix",
        requirements=[azure_only_requirement],
    )

    with patch(
        "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
        return_value=checks_metadata,
    ):
        loaded = load_checks_to_execute(
            bulk_checks_metadata=checks_metadata,
            bulk_compliance_frameworks={},
            compliance_frameworks=["csa_ccm_4.0"],
            provider=self.provider,  # "aws" — no checks declared for it
            universal_frameworks={"csa_ccm_4.0": framework},
        )

    assert loaded == set()
@@ -442,3 +442,123 @@ class TestComplianceOutput:
)
assert compliance_output.file_extension == ".csv"
class TestComplianceCheckHelperModule:
    """Tests for the new ``compliance_check`` leaf module that hosts
    ``get_check_compliance``.

    This module exists to break the cyclic import chain
    ``finding -> compliance.compliance -> universal.* -> finding`` that
    CodeQL flagged. It must be:

    - importable directly without pulling in the universal pipeline
    - re-exported by ``compliance.compliance`` for backward compatibility
    - the SAME function object, regardless of import path
    """

    def test_module_is_importable_directly(self):
        """The helper module must be importable on its own — it is the
        leaf used by ``finding.py`` to break the cyclic import chain."""
        from prowler.lib.outputs.compliance import compliance_check

        assert hasattr(compliance_check, "get_check_compliance")
        assert callable(compliance_check.get_check_compliance)

    def test_helper_module_only_depends_on_check_models_and_logger(self):
        """The helper must not pull in universal pipeline modules; that
        was the whole point of extracting it. Inspecting the module's
        own imports keeps it honest without polluting ``sys.modules``."""
        import inspect

        from prowler.lib.outputs.compliance import compliance_check

        source = inspect.getsource(compliance_check)
        # Only these two prowler imports are allowed in the leaf module
        assert "from prowler.lib.check.models import Check_Report" in source
        assert "from prowler.lib.logger import logger" in source
        # And NOT these (would re-introduce the cycle):
        assert "from prowler.lib.outputs.compliance.universal" not in source
        assert "from prowler.lib.outputs.finding" not in source
        assert "from prowler.lib.outputs.ocsf" not in source

    def test_re_export_from_compliance_compliance(self):
        """``compliance.compliance.get_check_compliance`` must point to
        the same function as ``compliance.compliance_check.get_check_compliance``."""
        from prowler.lib.outputs.compliance.compliance import (
            get_check_compliance as via_compliance,
        )
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance as via_helper,
        )

        # Identity, not equality: callers and test mocks must see ONE object.
        assert via_compliance is via_helper

    def test_re_export_from_finding_module(self):
        """``finding.get_check_compliance`` must point to the same
        function. Test mocks rely on this attribute existing on the
        ``prowler.lib.outputs.finding`` module."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance as via_helper,
        )
        from prowler.lib.outputs.finding import get_check_compliance as via_finding

        assert via_finding is via_helper

    def test_returns_empty_dict_on_unknown_check(self):
        """Sanity test of the function logic via the helper module."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance,
        )

        finding = mock.MagicMock()
        finding.check_metadata.CheckID = "unknown_check_id"
        # Empty bulk metadata: the check id cannot be resolved.
        result = get_check_compliance(finding, "aws", {})
        assert result == {}

    def test_filters_by_provider(self):
        """The function returns frameworks only for the matching provider."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance,
        )

        compliance_aws = mock.MagicMock(
            Framework="CIS",
            Version="1.4",
            Provider="AWS",
            Requirements=[mock.MagicMock(Id="2.1.3")],
        )
        compliance_azure = mock.MagicMock(
            Framework="CIS",
            Version="2.0",
            Provider="Azure",
            Requirements=[mock.MagicMock(Id="9.1")],
        )
        finding = mock.MagicMock()
        finding.check_metadata.CheckID = "shared_check"
        bulk = {
            "shared_check": mock.MagicMock(
                Compliance=[compliance_aws, compliance_azure]
            )
        }
        # Only AWS frameworks come back
        result = get_check_compliance(finding, "aws", bulk)
        assert "CIS-1.4" in result
        assert "CIS-2.0" not in result

    def test_returns_empty_dict_on_exception(self):
        """If iteration raises, the function logs the error and returns
        an empty dict (defensive behaviour)."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance,
        )

        # bulk_checks_metadata that raises when accessed → defensive path
        class Boom:
            def __contains__(self, _key):
                raise RuntimeError("boom")

        finding = mock.MagicMock()
        finding.check_metadata.CheckID = "any"
        result = get_check_compliance(finding, "aws", Boom())
        assert result == {}
@@ -0,0 +1,244 @@
"""Tests for display_compliance_table dispatch logic.
Validates that each compliance framework name is routed to the correct
table renderer via startswith matching, and that the universal early-return
takes precedence when applicable.
"""
from unittest.mock import patch
import pytest
from prowler.lib.check.compliance_models import (
ComplianceFramework,
OutputsConfig,
TableConfig,
UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.compliance import display_compliance_table
MODULE = "prowler.lib.outputs.compliance.compliance"
# Common args shared by every call — the actual values don't matter
# because we mock the downstream renderers.
_COMMON = dict(
findings=[],
bulk_checks_metadata={},
output_filename="out",
output_directory="/tmp",
compliance_overview=False,
)
# ── Dispatch to legacy table renderers ───────────────────────────────
class TestDispatchStartswith:
    """Each framework prefix must route to exactly one renderer."""

    @pytest.mark.parametrize(
        "framework_name",
        [
            "cis_1.4_aws",
            "cis_2.0_azure",
            "cis_3.0_gcp",
            "cis_6.0_m365",
            "cis_1.10_kubernetes",
        ],
    )
    @patch(f"{MODULE}.get_cis_table")
    def test_cis_dispatch(self, mock_fn, framework_name):
        """Every ``cis_*`` framework id routes to ``get_cis_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["ens_rd2022_aws", "ens_rd2022_azure", "ens_rd2022_gcp"],
    )
    @patch(f"{MODULE}.get_ens_table")
    def test_ens_dispatch(self, mock_fn, framework_name):
        """``ens_rd2022_*`` framework ids route to ``get_ens_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["mitre_attack_aws", "mitre_attack_azure", "mitre_attack_gcp"],
    )
    @patch(f"{MODULE}.get_mitre_attack_table")
    def test_mitre_dispatch(self, mock_fn, framework_name):
        """``mitre_attack_*`` framework ids route to ``get_mitre_attack_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["kisa_isms_p_2023_aws", "kisa_isms_p_2023_korean_aws"],
    )
    @patch(f"{MODULE}.get_kisa_ismsp_table")
    def test_kisa_dispatch(self, mock_fn, framework_name):
        """``kisa_isms_p_*`` ids (including the Korean variant) route to
        ``get_kisa_ismsp_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        [
            "prowler_threatscore_aws",
            "prowler_threatscore_azure",
            "prowler_threatscore_gcp",
            "prowler_threatscore_kubernetes",
            "prowler_threatscore_m365",
            "prowler_threatscore_alibabacloud",
        ],
    )
    @patch(f"{MODULE}.get_prowler_threatscore_table")
    def test_threatscore_dispatch(self, mock_fn, framework_name):
        """``prowler_threatscore_*`` ids route to ``get_prowler_threatscore_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        [
            "csa_ccm_4.0_aws",
            "csa_ccm_4.0_azure",
            "csa_ccm_4.0_gcp",
            "csa_ccm_4.0_oraclecloud",
            "csa_ccm_4.0_alibabacloud",
        ],
    )
    @patch(f"{MODULE}.get_csa_table")
    def test_csa_dispatch(self, mock_fn, framework_name):
        """``csa_ccm_*`` framework ids route to ``get_csa_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["c5_aws", "c5_azure", "c5_gcp"],
    )
    @patch(f"{MODULE}.get_c5_table")
    def test_c5_dispatch(self, mock_fn, framework_name):
        """``c5_*`` framework ids route to ``get_c5_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        [
            "soc2_aws",
            "hipaa_aws",
            "gdpr_aws",
            "nist_800_53_revision_4_aws",
            "pci_3.2.1_aws",
            "iso27001_2013_aws",
            "aws_well_architected_framework_security_pillar_aws",
            "fedramp_low_revision_4_aws",
            "cisa_aws",
        ],
    )
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_generic_dispatch(self, mock_fn, framework_name):
        """Frameworks without a dedicated renderer fall back to
        ``get_generic_compliance_table``."""
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()
# ── No false matches (the old `in` bug) ─────────────────────────────
class TestNoFalseSubstringMatches:
    """Frameworks that previously could false-match with `in` must NOT
    be routed to the wrong renderer now that we use startswith."""

    # FIX: this test previously patched ``get_ens_table`` but named and
    # asserted the mock as ``mock_cis``, so the regression it claims to
    # guard ("cisa" must not hit the CIS renderer) was never actually
    # tested. Patch ``get_cis_table`` so the assertion is meaningful.
    @patch(f"{MODULE}.get_cis_table")
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_cisa_does_not_match_cis(self, mock_generic, mock_cis):
        """'cisa_aws' must NOT match startswith('cis_')."""
        display_compliance_table(compliance_framework="cisa_aws", **_COMMON)
        mock_generic.assert_called_once()
        mock_cis.assert_not_called()

    @patch(f"{MODULE}.get_prowler_threatscore_table")
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_threatscore_prefix_not_partial(self, mock_generic, mock_ts):
        """A hypothetical 'threatscore_custom_aws' must NOT match
        startswith('prowler_threatscore_')."""
        display_compliance_table(
            compliance_framework="threatscore_custom_aws", **_COMMON
        )
        mock_generic.assert_called_once()
        mock_ts.assert_not_called()

    @patch(f"{MODULE}.get_ens_table")
    @patch(f"{MODULE}.get_prowler_threatscore_table")
    def test_prowler_threatscore_does_not_match_ens(self, mock_ts, mock_ens):
        """'prowler_threatscore_aws' must hit threatscore, never ens."""
        display_compliance_table(
            compliance_framework="prowler_threatscore_aws", **_COMMON
        )
        mock_ts.assert_called_once()
        mock_ens.assert_not_called()
# ── Universal early-return ───────────────────────────────────────────
class TestUniversalEarlyReturn:
    """The universal path must take precedence over the elif chain."""

    @staticmethod
    def _make_fw():
        """Build a universal CIS framework carrying a ``TableConfig``."""
        requirement = UniversalComplianceRequirement(
            id="1.1",
            description="d",
            attributes={},
            checks={"aws": ["check_a"]},
        )
        return ComplianceFramework(
            framework="CIS",
            name="CIS",
            provider="AWS",
            version="5.0",
            description="d",
            requirements=[requirement],
            outputs=OutputsConfig(table_config=TableConfig(group_by="_default")),
        )

    @patch(f"{MODULE}.get_universal_table")
    @patch(f"{MODULE}.get_cis_table")
    def test_universal_takes_precedence_over_cis(self, mock_cis, mock_universal):
        """A CIS framework present in ``universal_frameworks`` with a
        ``TableConfig`` must be rendered by the universal renderer, never
        by ``get_cis_table``."""
        framework = self._make_fw()
        display_compliance_table(
            compliance_framework="cis_5.0_aws",
            universal_frameworks={"cis_5.0_aws": framework},
            **_COMMON,
        )
        mock_universal.assert_called_once()
        mock_cis.assert_not_called()

    @patch(f"{MODULE}.get_universal_table")
    @patch(f"{MODULE}.get_cis_table")
    def test_falls_through_without_table_config(self, mock_cis, mock_universal):
        """Without a ``TableConfig`` the universal early-return must not
        trigger; the legacy elif chain handles the framework instead."""
        framework = self._make_fw()
        framework.outputs = None
        display_compliance_table(
            compliance_framework="cis_5.0_aws",
            universal_frameworks={"cis_5.0_aws": framework},
            **_COMMON,
        )
        mock_cis.assert_called_once()
        mock_universal.assert_not_called()

    @patch(f"{MODULE}.get_universal_table")
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_falls_through_when_not_in_universal_dict(
        self, mock_generic, mock_universal
    ):
        """An empty ``universal_frameworks`` dict means the legacy path runs."""
        display_compliance_table(
            compliance_framework="soc2_aws",
            universal_frameworks={},
            **_COMMON,
        )
        mock_generic.assert_called_once()
        mock_universal.assert_not_called()
@@ -0,0 +1,730 @@
"""Tests for process_universal_compliance_frameworks and --list-compliance fixes.
Validates that the pre-processing step:
- generates both CSV and OCSF outputs for universal frameworks
- always generates OCSF (no output-format gate)
- skips frameworks without outputs or table_config
- skips frameworks not in universal_frameworks
- returns the set of processed names for removal from the legacy loop
- works across different providers
Also validates that print_compliance_frameworks and print_compliance_requirements
work with universal ComplianceFramework objects (dict checks, None provider).
"""
import json
import os
from datetime import datetime, timezone
from types import SimpleNamespace
import pytest
from prowler.lib.check.check import (
print_compliance_frameworks,
print_compliance_requirements,
)
from prowler.lib.check.compliance_models import (
AttributeMetadata,
ComplianceFramework,
OutputsConfig,
TableConfig,
UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.compliance import (
process_universal_compliance_frameworks,
)
from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
OCSFComplianceOutput,
)
from prowler.lib.outputs.compliance.universal.universal_output import (
UniversalComplianceOutput,
)
@pytest.fixture(autouse=True)
def _create_compliance_dir(tmp_path):
    """Ensure ``<tmp_path>/compliance`` exists before each test runs."""
    (tmp_path / "compliance").mkdir(parents=True, exist_ok=True)
# ── Helpers ──────────────────────────────────────────────────────────
def _make_finding(check_id, status="PASS", provider="aws"):
"""Create a mock Finding with all fields needed by both output classes."""
finding = SimpleNamespace()
finding.provider = provider
finding.account_uid = "123456789012"
finding.account_name = "test-account"
finding.account_email = ""
finding.account_organization_uid = "org-123"
finding.account_organization_name = "test-org"
finding.account_tags = {"env": "test"}
finding.region = "us-east-1"
finding.status = status
finding.status_extended = f"{check_id} is {status}"
finding.resource_uid = f"arn:aws:iam::123456789012:{check_id}"
finding.resource_name = check_id
finding.resource_details = "some details"
finding.resource_metadata = {}
finding.resource_tags = {"Name": "test"}
finding.partition = "aws"
finding.muted = False
finding.check_id = check_id
finding.uid = "test-finding-uid"
finding.timestamp = datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
finding.prowler_version = "5.0.0"
finding.compliance = {"TestFW-1.0": ["1.1"]}
finding.metadata = SimpleNamespace(
Provider=provider,
CheckID=check_id,
CheckTitle=f"Title for {check_id}",
CheckType=["test-type"],
Description=f"Description for {check_id}",
Severity="medium",
ServiceName="iam",
ResourceType="aws-iam-role",
Risk="test-risk",
RelatedUrl="https://example.com",
Remediation=SimpleNamespace(
Recommendation=SimpleNamespace(Text="Fix it", Url="https://fix.com"),
),
DependsOn=[],
RelatedTo=[],
Categories=["test"],
Notes="",
AdditionalURLs=[],
)
return finding
def _make_universal_framework(name="TestFW", version="1.0", with_table_config=True):
    """Build a ComplianceFramework with an optional ``table_config``."""
    table_outputs = (
        OutputsConfig(table_config=TableConfig(group_by="Section"))
        if with_table_config
        else None
    )
    return ComplianceFramework(
        framework=name,
        name=f"{name} Framework",
        provider="AWS",
        version=version,
        description="Test framework",
        requirements=[
            UniversalComplianceRequirement(
                id="1.1",
                description="Test requirement",
                attributes={"Section": "IAM"},
                checks={"aws": ["check_a"]},
            ),
        ],
        attributes_metadata=[AttributeMetadata(key="Section", type="str")],
        outputs=table_outputs,
    )
# ── Tests ────────────────────────────────────────────────────────────
class TestProcessUniversalComplianceFrameworks:
    """Core tests for the extracted pre-processing function."""

    def test_generates_csv_and_ocsf_outputs(self, tmp_path):
        """Both CSV and OCSF outputs are appended to generated_outputs."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == {"test_fw_1.0"}
        # Exactly two outputs per framework: CSV first, OCSF second.
        assert len(generated["compliance"]) == 2
        assert isinstance(generated["compliance"][0], UniversalComplianceOutput)
        assert isinstance(generated["compliance"][1], OCSFComplianceOutput)

    def test_ocsf_always_generated_no_format_gate(self, tmp_path):
        """OCSF output is generated regardless of output_formats — no gate."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )
        ocsf_outputs = [
            o for o in generated["compliance"] if isinstance(o, OCSFComplianceOutput)
        ]
        assert len(ocsf_outputs) == 1

    def test_csv_file_written(self, tmp_path):
        """CSV file is created with expected content."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )
        # File name pattern: <output_filename>_<framework>.csv under compliance/
        csv_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.csv"
        assert csv_path.exists()
        content = csv_path.read_text()
        assert "PROVIDER" in content
        assert "REQUIREMENTS_ATTRIBUTES_SECTION" in content

    def test_ocsf_file_written(self, tmp_path):
        """OCSF JSON file is created with valid content."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )
        ocsf_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.ocsf.json"
        assert ocsf_path.exists()
        data = json.loads(ocsf_path.read_text())
        assert isinstance(data, list)
        assert len(data) >= 1
        # 2003 is the OCSF "Compliance Finding" class_uid.
        assert data[0]["class_uid"] == 2003

    def test_returns_processed_names(self, tmp_path):
        """Returns the set of framework names that were processed."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0", "legacy_fw"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == {"test_fw_1.0"}
        assert "legacy_fw" not in processed
class TestSkipConditions:
    """Tests for frameworks that should NOT be processed."""

    def test_skips_framework_not_in_universal(self, tmp_path):
        """Frameworks not in universal_frameworks dict are skipped."""
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"cis_aws_1.4"},
            universal_frameworks={},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == set()
        assert len(generated["compliance"]) == 0

    def test_skips_framework_without_outputs(self, tmp_path):
        """Frameworks with outputs=None are skipped."""
        fw = _make_universal_framework(with_table_config=False)
        # outputs is None since with_table_config=False
        assert fw.outputs is None
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == set()
        assert len(generated["compliance"]) == 0

    def test_skips_framework_with_outputs_but_no_table_config(self, tmp_path):
        """Frameworks with outputs but table_config=None are skipped."""
        fw = _make_universal_framework()
        # Manually set table_config to None while keeping outputs
        fw.outputs = OutputsConfig(table_config=None)
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == set()
        assert len(generated["compliance"]) == 0

    def test_empty_input_frameworks(self, tmp_path):
        """No processing when input set is empty."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks=set(),
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == set()
        assert len(generated["compliance"]) == 0
class TestMixedFrameworks:
    """Tests with a mix of universal and legacy frameworks."""

    def test_only_universal_processed_legacy_untouched(self, tmp_path):
        """Only universal frameworks are processed; legacy names are not returned."""
        universal_fw = _make_universal_framework()
        generated = {"compliance": []}
        all_frameworks = {"test_fw_1.0", "cis_aws_1.4", "nist_800_53_aws"}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks=all_frameworks,
            universal_frameworks={"test_fw_1.0": universal_fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == {"test_fw_1.0"}
        # 2 outputs for the one universal framework (CSV + OCSF)
        assert len(generated["compliance"]) == 2

    def test_removal_from_input_set(self, tmp_path):
        """Caller can subtract processed set from input to get legacy-only frameworks."""
        universal_fw = _make_universal_framework()
        generated = {"compliance": []}
        input_frameworks = {"test_fw_1.0", "cis_aws_1.4", "nist_800_53_aws"}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks=input_frameworks,
            universal_frameworks={"test_fw_1.0": universal_fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        # The difference is what the legacy loop should still handle.
        remaining = input_frameworks - processed
        assert remaining == {"cis_aws_1.4", "nist_800_53_aws"}

    def test_multiple_universal_frameworks(self, tmp_path):
        """Multiple universal frameworks each get CSV + OCSF."""
        fw1 = _make_universal_framework(name="FW1", version="1.0")
        fw2 = _make_universal_framework(name="FW2", version="2.0")
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"fw1_1.0", "fw2_2.0", "legacy"},
            universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        assert processed == {"fw1_1.0", "fw2_2.0"}
        # 2 frameworks × 2 outputs each = 4
        assert len(generated["compliance"]) == 4
        csv_outputs = [
            o
            for o in generated["compliance"]
            if isinstance(o, UniversalComplianceOutput)
        ]
        ocsf_outputs = [
            o for o in generated["compliance"] if isinstance(o, OCSFComplianceOutput)
        ]
        assert len(csv_outputs) == 2
        assert len(ocsf_outputs) == 2
class TestProviderVariants:
    """Verify the function works for different providers."""

    @pytest.mark.parametrize(
        "provider",
        [
            "aws",
            "azure",
            "gcp",
            "kubernetes",
            "m365",
            "github",
            "oraclecloud",
            "alibabacloud",
            "nhn",
        ],
    )
    def test_all_providers_produce_outputs(self, tmp_path, provider):
        """Each provider generates CSV + OCSF when given a universal framework."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a", provider=provider)],
            output_directory=str(tmp_path),
            output_filename="out",
            provider=provider,
            generated_outputs=generated,
        )
        assert processed == {"test_fw_1.0"}
        # CSV first, then OCSF — same order as for the AWS case.
        assert len(generated["compliance"]) == 2
        assert isinstance(generated["compliance"][0], UniversalComplianceOutput)
        assert isinstance(generated["compliance"][1], OCSFComplianceOutput)
class TestEmptyFindings:
    """Behaviour when the scan produced no findings at all."""

    def test_still_processed_with_empty_findings(self, tmp_path):
        """A universal framework is still marked as processed — and both
        outputs appended — even when there are no findings to report."""
        framework = _make_universal_framework()
        outputs = {"compliance": []}
        handled = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": framework},
            finding_outputs=[],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=outputs,
        )
        assert handled == {"test_fw_1.0"}
        # CSV + OCSF are still appended; they simply carry no data rows.
        assert len(outputs["compliance"]) == 2
class TestFilePaths:
    """Verify correct file path construction."""

    def test_csv_path_format(self, tmp_path):
        """The CSV output lands at ``<dir>/compliance/<filename>_<fw>.csv``."""
        framework = _make_universal_framework()
        outputs = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"csa_ccm_4.0"},
            universal_frameworks={"csa_ccm_4.0": framework},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_report",
            provider="aws",
            generated_outputs=outputs,
        )
        csv_output = outputs["compliance"][0]
        expected = f"{tmp_path}/compliance/prowler_report_csa_ccm_4.0.csv"
        assert csv_output.file_path == expected

    def test_ocsf_path_format(self, tmp_path):
        """The OCSF output lands at ``<dir>/compliance/<filename>_<fw>.ocsf.json``."""
        framework = _make_universal_framework()
        outputs = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"csa_ccm_4.0"},
            universal_frameworks={"csa_ccm_4.0": framework},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_report",
            provider="aws",
            generated_outputs=outputs,
        )
        ocsf_output = outputs["compliance"][1]
        expected = f"{tmp_path}/compliance/prowler_report_csa_ccm_4.0.ocsf.json"
        assert ocsf_output.file_path == expected
# ── Tests for --list-compliance fix ──────────────────────────────────
def _make_legacy_compliance():
"""Create a mock legacy Compliance-like object with the expected attributes."""
return SimpleNamespace(
Framework="CIS",
Provider="AWS",
Version="1.4",
Requirements=[
SimpleNamespace(
Id="2.1.3",
Description="Ensure MFA Delete is enabled",
Checks=["s3_bucket_mfa_delete"],
),
],
)
class TestPrintComplianceFrameworks:
    """Tests for print_compliance_frameworks with universal frameworks."""

    def test_includes_universal_frameworks(self, capsys):
        """Both legacy and universal framework keys show up in the listing."""
        frameworks = {
            "cis_1.4_aws": _make_legacy_compliance(),
            "csa_ccm_4.0": _make_universal_framework(),
        }
        print_compliance_frameworks(frameworks)
        out = capsys.readouterr().out
        assert "cis_1.4_aws" in out
        assert "csa_ccm_4.0" in out

    def test_count_includes_both(self, capsys):
        """The reported total counts legacy and universal frameworks alike."""
        frameworks = {
            "cis_1.4_aws": _make_legacy_compliance(),
            "csa_ccm_4.0": _make_universal_framework(),
        }
        print_compliance_frameworks(frameworks)
        assert "2" in capsys.readouterr().out

    def test_universal_only(self, capsys):
        """Listing works when only universal frameworks are present."""
        print_compliance_frameworks({"csa_ccm_4.0": _make_universal_framework()})
        out = capsys.readouterr().out
        assert "csa_ccm_4.0" in out
        assert "1" in out
class TestPrintComplianceRequirements:
    """Tests for print_compliance_requirements with universal frameworks."""

    def test_list_checks_universal_framework(self, capsys):
        """Requirement ids and their mapped check names are printed."""
        frameworks = {"test_fw_1.0": _make_universal_framework()}
        print_compliance_requirements(frameworks, ["test_fw_1.0"])
        out = capsys.readouterr().out
        assert "1.1" in out
        assert "check_a" in out

    def test_dict_checks_universal_framework(self, capsys):
        """Dict-shaped checks are rendered with a provider prefix."""
        requirement = UniversalComplianceRequirement(
            id="A&A-01",
            description="Audit & Assurance",
            attributes={"Section": "A&A"},
            checks={"aws": ["check_a", "check_b"], "azure": ["check_c"]},
        )
        framework = ComplianceFramework(
            framework="CSA_CCM",
            name="CSA CCM 4.0",
            version="4.0",
            description="Cloud Controls Matrix",
            requirements=[requirement],
        )
        print_compliance_requirements({"csa_ccm_4.0": framework}, ["csa_ccm_4.0"])
        out = capsys.readouterr().out
        assert "A&A-01" in out
        for expected in ("[aws] check_a", "[aws] check_b", "[azure] check_c"):
            assert expected in out

    def test_none_provider_shows_multi_provider(self, capsys):
        """Frameworks with provider=None are labelled 'Multi-provider'."""
        framework = ComplianceFramework(
            framework="CSA_CCM",
            name="CSA CCM 4.0",
            version="4.0",
            description="Cloud Controls Matrix",
            requirements=[
                UniversalComplianceRequirement(
                    id="1.1",
                    description="test",
                    attributes={},
                    checks={"aws": ["check_a"]},
                ),
            ],
        )
        print_compliance_requirements({"csa_ccm_4.0": framework}, ["csa_ccm_4.0"])
        assert "Multi-provider" in capsys.readouterr().out
# ── Idempotency tests ────────────────────────────────────────────────
class TestIdempotency:
    """The function must be safe to invoke multiple times for the same
    framework. Repeated calls must reuse writers tracked in
    ``generated_outputs["compliance"]`` instead of recreating them.

    This guards against:
    - duplicate writer entries in generated_outputs (regular pipeline
      treats one writer per framework)
    - the OCSF append-bug where a second writer would emit
      ``[...]<new>...]`` and break the JSON array.
    """

    def test_second_call_does_not_duplicate_writers(self, tmp_path):
        # Identical kwargs on both calls: the second call must be a no-op
        # with respect to the writer list.
        fw = _make_universal_framework()
        generated = {"compliance": []}
        kwargs = dict(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )
        first = process_universal_compliance_frameworks(**kwargs)
        first_count = len(generated["compliance"])
        second = process_universal_compliance_frameworks(**kwargs)
        second_count = len(generated["compliance"])
        assert first == {"test_fw_1.0"}
        assert second == {"test_fw_1.0"}  # still reported as processed
        assert first_count == 2  # CSV + OCSF
        assert second_count == 2  # NO duplication

    def test_second_call_keeps_ocsf_json_valid(self, tmp_path):
        """End-to-end: after two calls the OCSF JSON file must still be
        a single, valid JSON array — not the broken ``[...]...]`` form."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        kwargs = dict(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )
        process_universal_compliance_frameworks(**kwargs)
        process_universal_compliance_frameworks(**kwargs)
        ocsf_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.ocsf.json"
        data = json.loads(ocsf_path.read_text())  # Will raise on invalid JSON
        assert isinstance(data, list)
        assert len(data) >= 1

    def test_reuses_existing_writer_object(self, tmp_path):
        """The CSV/OCSF writer instances appended on first call must be
        the SAME objects after a second call — not fresh ones."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        kwargs = dict(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )
        process_universal_compliance_frameworks(**kwargs)
        first_writers = list(generated["compliance"])
        process_universal_compliance_frameworks(**kwargs)
        second_writers = list(generated["compliance"])
        # Same identity, same length — reused, not recreated.
        assert len(first_writers) == len(second_writers)
        for a, b in zip(first_writers, second_writers):
            assert a is b

    def test_idempotency_across_mixed_frameworks(self, tmp_path):
        """When the second call adds a new framework, the new one is
        created while existing ones are NOT recreated."""
        fw1 = _make_universal_framework(name="FW1", version="1.0")
        fw2 = _make_universal_framework(name="FW2", version="2.0")
        generated = {"compliance": []}
        # First call: only FW1
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"fw1_1.0"},
            universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        first_writers = list(generated["compliance"])
        assert len(first_writers) == 2
        # Second call: includes both. FW1 must be reused, FW2 created fresh.
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"fw1_1.0", "fw2_2.0"},
            universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )
        second_writers = list(generated["compliance"])
        assert len(second_writers) == 4  # 2 (FW1 reused) + 2 new (FW2)
        # FW1 writer instances unchanged
        assert second_writers[0] is first_writers[0]
        assert second_writers[1] is first_writers[1]
@@ -2,6 +2,7 @@ import json
from datetime import datetime, timezone
from types import SimpleNamespace
from py_ocsf_models.events.base_event import StatusID as EventStatusID
from py_ocsf_models.events.findings.compliance_finding import ComplianceFinding
from py_ocsf_models.events.findings.compliance_finding_type_id import (
ComplianceFindingTypeID,
@@ -18,6 +19,7 @@ from prowler.lib.check.compliance_models import (
)
from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
OCSFComplianceOutput,
_sanitize_resource_data,
)
@@ -473,3 +475,159 @@ class TestOCSFComplianceOutput:
cf = output.data[0]
assert cf.unmapped["requirement_attributes"]["section"] == "Logging"
assert "internal_note" not in cf.unmapped["requirement_attributes"]
class TestSanitizeResourceData:
    """Unit tests for the _sanitize_resource_data helper.

    Service resources may carry non-JSON-serializable objects (e.g. raw
    Pydantic models such as ``Trail`` or ``LifecyclePolicy``). The helper
    must convert them so the resulting ComplianceFinding can be serialized.
    """

    def test_dict_passthrough(self):
        # Plain JSON-compatible dicts are passed through unchanged.
        result = _sanitize_resource_data("details", {"a": 1, "b": "two"})
        assert result == {"details": "details", "metadata": {"a": 1, "b": "two"}}

    def test_none_metadata(self):
        # None metadata is preserved as-is, not replaced with an empty dict.
        result = _sanitize_resource_data("details", None)
        assert result == {"details": "details", "metadata": None}

    def test_pydantic_v2_model_dump(self):
        # Objects exposing Pydantic v2's ``model_dump`` are converted via it.
        class FakeV2Model:
            def model_dump(self):
                return {"name": "trail-1", "region": "us-east-1"}

        result = _sanitize_resource_data("d", {"trail": FakeV2Model()})
        assert result["metadata"]["trail"] == {
            "name": "trail-1",
            "region": "us-east-1",
        }

    def test_pydantic_v1_dict(self):
        # Pydantic v1-style objects exposing ``dict()`` are supported too.
        class FakeV1Model:
            def dict(self):
                return {"name": "policy-1", "schedule": "daily"}

        result = _sanitize_resource_data("d", {"policy": FakeV1Model()})
        assert result["metadata"]["policy"] == {
            "name": "policy-1",
            "schedule": "daily",
        }

    def test_nested_pydantic_in_list(self):
        # Models inside lists are converted element by element.
        class FakeModel:
            def model_dump(self):
                return {"id": "x"}

        result = _sanitize_resource_data("d", {"items": [FakeModel(), FakeModel()]})
        assert result["metadata"]["items"] == [{"id": "x"}, {"id": "x"}]

    def test_nested_dict_recursion(self):
        # Sanitization recurses into nested dicts and leaves plain values alone.
        class FakeInner:
            def model_dump(self):
                return {"k": "v"}

        result = _sanitize_resource_data(
            "d", {"outer": {"inner": FakeInner(), "x": [1, 2]}}
        )
        assert result["metadata"]["outer"]["inner"] == {"k": "v"}
        assert result["metadata"]["outer"]["x"] == [1, 2]

    def test_tuple_to_list(self):
        # Tuples are not a JSON type; they must come back as lists.
        result = _sanitize_resource_data("d", {"t": (1, 2, "three")})
        assert result["metadata"]["t"] == [1, 2, "three"]

    def test_non_string_dict_keys_coerced(self):
        # JSON object keys must be strings, so int keys get stringified.
        result = _sanitize_resource_data("d", {1: "a", 2: "b"})
        assert result["metadata"] == {"1": "a", "2": "b"}

    def test_unknown_object_falls_back_to_str(self):
        # Objects with no dump method are stringified as a last resort.
        class Opaque:
            def __str__(self):
                return "opaque-repr"

        result = _sanitize_resource_data("d", {"thing": Opaque()})
        assert result["metadata"]["thing"] == "opaque-repr"

    def test_circular_reference_falls_back_to_empty(self):
        a = {}
        a["self"] = a
        # json.dumps raises ValueError on recursion → fallback to empty metadata
        result = _sanitize_resource_data("d", a)
        assert result == {"details": "d", "metadata": {}}

    def test_serializes_via_full_finding_pipeline(self):
        """End-to-end: a finding with a non-serializable resource_metadata
        produces a JSON-serializable ComplianceFinding."""

        class TrailLike:
            def __init__(self):
                self.name = "trail-A"
                self.kms_key_id = "arn:aws:kms:..."

            def model_dump(self):
                return {"name": self.name, "kms_key_id": self.kms_key_id}

        finding = _make_finding("check_a")
        finding.resource_metadata = {"trail": TrailLike()}
        req = _simple_requirement()
        fw = _make_framework([req])
        output = OCSFComplianceOutput(findings=[finding], framework=fw, provider="aws")
        # Serialize the resulting ComplianceFinding — must NOT raise
        cf = output.data[0]
        if hasattr(cf, "model_dump_json"):
            json_output = cf.model_dump_json(exclude_none=True)
        else:
            json_output = cf.json(exclude_none=True)
        payload = json.loads(json_output)
        # Confirm the trail object made it through as a plain dict
        assert payload["resources"][0]["data"]["metadata"]["trail"]["name"] == "trail-A"
class TestEventStatusInline:
    """Tests for the inlined event_status logic that replaced
    OCSF.get_finding_status_id() to break the cyclic import."""

    @staticmethod
    def _status_of(muted):
        """Build a one-finding output with the given muted flag and
        return the resulting ComplianceFinding."""
        finding = _make_finding("check_a")
        finding.muted = muted
        framework = _make_framework([_simple_requirement()])
        output = OCSFComplianceOutput(
            findings=[finding], framework=framework, provider="aws"
        )
        return output.data[0]

    def test_unmuted_finding_status_new(self):
        """An unmuted finding maps to the 'New' event status."""
        cf = self._status_of(False)
        assert cf.status_id == EventStatusID.New.value
        assert cf.status == EventStatusID.New.name

    def test_muted_finding_status_suppressed(self):
        """A muted finding maps to the 'Suppressed' event status."""
        cf = self._status_of(True)
        assert cf.status_id == EventStatusID.Suppressed.value
        assert cf.status == EventStatusID.Suppressed.name
class TestNoTopLevelOCSFImport:
    """Regression test: the top-level OCSF/Finding imports were removed
    to break the CodeQL cyclic-import warnings. Ensure they stay out of
    the runtime namespace of the module (TYPE_CHECKING block only)."""

    @staticmethod
    def _module_names():
        """Return the runtime namespace of the ocsf_compliance module."""
        import prowler.lib.outputs.compliance.universal.ocsf_compliance as mod

        return dir(mod)

    def test_finding_not_in_runtime_namespace(self):
        assert "Finding" not in self._module_names()

    def test_ocsf_class_not_imported(self):
        assert "OCSF" not in self._module_names()
@@ -0,0 +1,568 @@
from types import SimpleNamespace
from prowler.lib.check.compliance_models import (
AttributeMetadata,
ComplianceFramework,
OutputFormats,
OutputsConfig,
TableConfig,
UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.universal.universal_output import (
UniversalComplianceOutput,
)
def _make_finding(check_id, status="PASS", compliance_map=None):
"""Create a mock Finding for output tests."""
finding = SimpleNamespace()
finding.provider = "aws"
finding.account_uid = "123456789012"
finding.account_name = "test-account"
finding.region = "us-east-1"
finding.status = status
finding.status_extended = f"{check_id} is {status}"
finding.resource_uid = f"arn:aws:iam::123456789012:{check_id}"
finding.resource_name = check_id
finding.muted = False
finding.check_id = check_id
finding.metadata = SimpleNamespace(
Provider="aws",
CheckID=check_id,
Severity="medium",
)
finding.compliance = compliance_map or {}
return finding
def _make_framework(requirements, attrs_metadata=None, table_config=None):
    """Assemble a ComplianceFramework fixture for the AWS provider.

    ``table_config`` is optional; when given it is wrapped in an
    OutputsConfig so table/CSV generation is enabled.
    """
    outputs = OutputsConfig(table_config=table_config) if table_config else None
    return ComplianceFramework(
        framework="TestFW",
        name="Test Framework",
        provider="AWS",
        version="1.0",
        description="Test framework",
        requirements=requirements,
        attributes_metadata=attrs_metadata,
        outputs=outputs,
    )
class TestDynamicCSVColumns:
    def test_columns_match_metadata(self, tmp_path):
        """Each declared attribute key becomes a Requirements_Attributes_* column."""
        requirement = UniversalComplianceRequirement(
            id="1.1",
            description="test",
            attributes={"Section": "IAM", "SubSection": "Auth"},
            checks={"aws": ["check_a"]},
        )
        framework = _make_framework(
            [requirement],
            [
                AttributeMetadata(key="Section", type="str"),
                AttributeMetadata(key="SubSection", type="str"),
            ],
            TableConfig(group_by="Section"),
        )
        output = UniversalComplianceOutput(
            findings=[_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]})],
            framework=framework,
            file_path=str(tmp_path / "test.csv"),
        )
        assert len(output.data) == 1
        row = output.data[0].dict()
        assert row["Requirements_Attributes_Section"] == "IAM"
        assert row["Requirements_Attributes_SubSection"] == "Auth"
class TestManualRequirements:
    def test_manual_status(self, tmp_path):
        """Requirements with no mapped checks are emitted as MANUAL rows."""
        requirements = [
            UniversalComplianceRequirement(
                id="1.1",
                description="test",
                attributes={"Section": "IAM"},
                checks={"aws": ["check_a"]},
            ),
            UniversalComplianceRequirement(
                id="manual-1",
                description="manual check",
                attributes={"Section": "Governance"},
                checks={},
            ),
        ]
        framework = _make_framework(
            requirements,
            [AttributeMetadata(key="Section", type="str")],
            TableConfig(group_by="Section"),
        )
        output = UniversalComplianceOutput(
            findings=[_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]})],
            framework=framework,
            file_path=str(tmp_path / "test.csv"),
        )
        # One automated row plus one synthetic MANUAL row.
        assert len(output.data) == 2
        manual = [r.dict() for r in output.data if r.dict()["Status"] == "MANUAL"]
        assert len(manual) == 1
        assert manual[0]["Requirements_Id"] == "manual-1"
        assert manual[0]["ResourceId"] == "manual_check"
class TestMITREExtraColumns:
    def test_mitre_columns_present(self, tmp_path):
        """MITRE-specific fields surface as dedicated CSV columns."""
        requirement = UniversalComplianceRequirement(
            id="T1190",
            description="Exploit",
            attributes={},
            checks={"aws": ["check_a"]},
            tactics=["Initial Access"],
            sub_techniques=[],
            platforms=["IaaS"],
            technique_url="https://attack.mitre.org/techniques/T1190/",
        )
        framework = _make_framework(
            [requirement], None, TableConfig(group_by="_Tactics")
        )
        output = UniversalComplianceOutput(
            findings=[_make_finding("check_a", "PASS", {"TestFW-1.0": ["T1190"]})],
            framework=framework,
            file_path=str(tmp_path / "test.csv"),
        )
        assert len(output.data) == 1
        row = output.data[0].dict()
        assert row["Requirements_Tactics"] == "Initial Access"
        assert "Requirements_TechniqueURL" in row
class TestCSVFileWrite:
    def test_batch_write(self, tmp_path):
        """batch_write_data_to_file persists headers and row data to disk."""
        requirement = UniversalComplianceRequirement(
            id="1.1",
            description="test",
            attributes={"Section": "IAM"},
            checks={"aws": ["check_a"]},
        )
        framework = _make_framework(
            [requirement],
            [AttributeMetadata(key="Section", type="str")],
            TableConfig(group_by="Section"),
        )
        target = tmp_path / "test.csv"
        output = UniversalComplianceOutput(
            findings=[_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]})],
            framework=framework,
            file_path=str(target),
        )
        output.batch_write_data_to_file()
        # Verify file was created and has content
        content = target.read_text()
        assert "PROVIDER" in content  # Headers are uppercase
        assert "REQUIREMENTS_ATTRIBUTES_SECTION" in content
        assert "IAM" in content
class TestNoFindings:
    def test_empty_findings_no_data(self, tmp_path):
        """No findings means no data rows are produced at all."""
        requirement = UniversalComplianceRequirement(
            id="1.1",
            description="test",
            attributes={"Section": "IAM"},
            checks={"aws": ["check_a"]},
        )
        framework = _make_framework(
            [requirement], None, TableConfig(group_by="Section")
        )
        output = UniversalComplianceOutput(
            findings=[],
            framework=framework,
            file_path=str(tmp_path / "test.csv"),
        )
        assert len(output.data) == 0
class TestMultiProviderOutput:
    """Behavior of UniversalComplianceOutput when a requirement maps
    checks for several providers (dict-shaped ``checks``)."""

    def test_dict_checks_filtered_by_provider(self, tmp_path):
        """Only checks for the given provider appear in CSV output."""
        reqs = [
            UniversalComplianceRequirement(
                id="1.1",
                description="test",
                attributes={"Section": "IAM"},
                checks={"aws": ["check_a"], "azure": ["check_b"]},
            ),
        ]
        metadata = [
            AttributeMetadata(key="Section", type="str"),
        ]
        # Framework with provider unset → applicable to multiple providers.
        fw = ComplianceFramework(
            framework="MultiCloud",
            name="Multi",
            version="1.0",
            description="Test multi-provider",
            requirements=reqs,
            attributes_metadata=metadata,
            outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
        )
        findings = [
            _make_finding("check_a", "PASS", {"MultiCloud-1.0": ["1.1"]}),
            _make_finding("check_b", "FAIL", {"MultiCloud-1.0": ["1.1"]}),
        ]
        filepath = str(tmp_path / "test.csv")
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=filepath,
            provider="aws",
        )
        # Only check_a should match (it's the AWS check)
        assert len(output.data) == 1
        row_dict = output.data[0].dict()
        assert row_dict["Requirements_Attributes_Section"] == "IAM"

    def test_no_provider_includes_all(self, tmp_path):
        """Without provider filter, all checks from all providers are included."""
        reqs = [
            UniversalComplianceRequirement(
                id="1.1",
                description="test",
                attributes={"Section": "IAM"},
                checks={"aws": ["check_a"], "azure": ["check_b"]},
            ),
        ]
        metadata = [
            AttributeMetadata(key="Section", type="str"),
        ]
        fw = ComplianceFramework(
            framework="MultiCloud",
            name="Multi",
            version="1.0",
            description="Test multi-provider",
            requirements=reqs,
            attributes_metadata=metadata,
            outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
        )
        findings = [
            _make_finding("check_a", "PASS", {"MultiCloud-1.0": ["1.1"]}),
            _make_finding("check_b", "FAIL", {"MultiCloud-1.0": ["1.1"]}),
        ]
        filepath = str(tmp_path / "test.csv")
        # No ``provider`` argument → no filtering applied.
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=filepath,
        )
        # Both checks should be included without provider filter
        assert len(output.data) == 2

    def test_empty_dict_checks_is_manual(self, tmp_path):
        """Requirement with empty dict checks is treated as manual."""
        reqs = [
            UniversalComplianceRequirement(
                id="manual-1",
                description="manual check",
                attributes={"Section": "Governance"},
                checks={},
            ),
        ]
        metadata = [
            AttributeMetadata(key="Section", type="str"),
        ]
        fw = ComplianceFramework(
            framework="MultiCloud",
            name="Multi",
            version="1.0",
            description="Test",
            requirements=reqs,
            attributes_metadata=metadata,
            outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
        )
        filepath = str(tmp_path / "test.csv")
        # The supplied finding maps to no requirement; only the synthetic
        # MANUAL row for the check-less requirement should appear.
        output = UniversalComplianceOutput(
            findings=[_make_finding("other_check", "PASS", {})],
            framework=fw,
            file_path=filepath,
            provider="aws",
        )
        manual_rows = [r for r in output.data if r.dict()["Status"] == "MANUAL"]
        assert len(manual_rows) == 1
        assert manual_rows[0].dict()["Requirements_Id"] == "manual-1"
class TestCSVExclude:
    def test_csv_false_excludes_column(self, tmp_path):
        """Attributes flagged ``csv=False`` never become CSV columns."""
        requirement = UniversalComplianceRequirement(
            id="1.1",
            description="test",
            attributes={"Section": "IAM", "Internal": "hidden"},
            checks={"aws": ["check_a"]},
        )
        metadata = [
            AttributeMetadata(
                key="Section", type="str", output_formats=OutputFormats(csv=True)
            ),
            AttributeMetadata(
                key="Internal", type="str", output_formats=OutputFormats(csv=False)
            ),
        ]
        framework = _make_framework(
            [requirement], metadata, TableConfig(group_by="Section")
        )
        output = UniversalComplianceOutput(
            findings=[_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]})],
            framework=framework,
            file_path=str(tmp_path / "test.csv"),
        )
        row = output.data[0].dict()
        assert "Requirements_Attributes_Section" in row
        assert "Requirements_Attributes_Internal" not in row
def _make_provider_finding(provider, check_id="check_a", status="PASS"):
    """Create a mock Finding for *provider*, pre-mapped to TestFW-1.0/1.1."""
    result = _make_finding(check_id, status, {"TestFW-1.0": ["1.1"]})
    result.provider = provider
    return result
def _simple_framework():
    """One-requirement framework whose check is mapped for every provider."""
    all_providers = [
        "aws",
        "azure",
        "gcp",
        "kubernetes",
        "m365",
        "github",
        "oraclecloud",
        "alibabacloud",
        "nhn",
        "unknown",
    ]
    requirement = UniversalComplianceRequirement(
        id="1.1",
        description="test",
        attributes={"Section": "IAM"},
        # A distinct list per provider key, matching the original shape.
        checks={name: ["check_a"] for name in all_providers},
    )
    return _make_framework(
        [requirement],
        [AttributeMetadata(key="Section", type="str")],
        TableConfig(group_by="Section"),
    )
class TestProviderHeaders:
    """Per-provider CSV header naming and column ordering.

    Columns 3 and 4 (after Provider, Description) carry provider-specific
    names that must match the legacy per-provider output models.
    """

    def test_aws_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("aws")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="aws",
        )
        row_dict = output.data[0].dict()
        assert "AccountId" in row_dict
        assert "Region" in row_dict
        assert row_dict["AccountId"] == "123456789012"
        assert row_dict["Region"] == "us-east-1"

    def test_azure_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("azure")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="azure",
        )
        row_dict = output.data[0].dict()
        # Azure renames the account/region pair to SubscriptionId/Location.
        assert "SubscriptionId" in row_dict
        assert "Location" in row_dict
        assert row_dict["SubscriptionId"] == "123456789012"
        assert row_dict["Location"] == "us-east-1"

    def test_gcp_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("gcp")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="gcp",
        )
        row_dict = output.data[0].dict()
        assert "ProjectId" in row_dict
        assert "Location" in row_dict
        assert row_dict["ProjectId"] == "123456789012"

    def test_kubernetes_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("kubernetes")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="kubernetes",
        )
        row_dict = output.data[0].dict()
        assert "Context" in row_dict
        assert "Namespace" in row_dict
        # Kubernetes Context maps to account_name
        assert row_dict["Context"] == "test-account"
        assert row_dict["Namespace"] == "us-east-1"

    def test_github_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("github")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="github",
        )
        row_dict = output.data[0].dict()
        assert "Account_Name" in row_dict
        assert "Account_Id" in row_dict
        # GitHub: Account_Name (pos 3) from account_name, Account_Id (pos 4) from account_uid
        assert row_dict["Account_Name"] == "test-account"
        assert row_dict["Account_Id"] == "123456789012"
        # Verify column order matches legacy (Account_Name before Account_Id)
        keys = list(row_dict.keys())
        assert keys.index("Account_Name") < keys.index("Account_Id")

    def test_unknown_provider_defaults(self, tmp_path):
        # Unrecognized providers fall back to the AWS-style headers.
        fw = _simple_framework()
        findings = [_make_provider_finding("unknown")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="unknown",
        )
        row_dict = output.data[0].dict()
        assert "AccountId" in row_dict
        assert "Region" in row_dict

    def test_none_provider_defaults(self, tmp_path):
        # Omitting ``provider`` also falls back to the default headers.
        fw = _simple_framework()
        findings = [_make_provider_finding("aws")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
        )
        row_dict = output.data[0].dict()
        assert "AccountId" in row_dict
        assert "Region" in row_dict

    def test_csv_write_azure_headers(self, tmp_path):
        # End-to-end: the provider-specific names survive the file write.
        fw = _simple_framework()
        findings = [_make_provider_finding("azure")]
        filepath = str(tmp_path / "test.csv")
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=filepath,
            provider="azure",
        )
        output.batch_write_data_to_file()
        with open(filepath, "r") as f:
            content = f.read()
        assert "SUBSCRIPTIONID" in content
        assert "LOCATION" in content
        # Should NOT have the default AccountId/Region headers
        assert "ACCOUNTID" not in content

    def test_column_order_matches_legacy(self, tmp_path):
        """Verify that the base column order matches the legacy per-provider models.

        Legacy models all define: Provider, Description, <col3>, <col4>, AssessmentDate, ...
        The universal output must preserve this exact order for backward compatibility.
        """
        # Expected column order per provider (positions 3 and 4 after Provider, Description)
        legacy_order = {
            "aws": ("AccountId", "Region"),
            "azure": ("SubscriptionId", "Location"),
            "gcp": ("ProjectId", "Location"),
            "kubernetes": ("Context", "Namespace"),
            "m365": ("TenantId", "Location"),
            "github": ("Account_Name", "Account_Id"),
            "oraclecloud": ("TenancyId", "Region"),
            "alibabacloud": ("AccountId", "Region"),
            "nhn": ("AccountId", "Region"),
        }
        for provider_name, (expected_col3, expected_col4) in legacy_order.items():
            fw = _simple_framework()
            findings = [_make_provider_finding(provider_name)]
            output = UniversalComplianceOutput(
                findings=findings,
                framework=fw,
                file_path=str(tmp_path / f"test_{provider_name}.csv"),
                provider=provider_name,
            )
            keys = list(output.data[0].dict().keys())
            assert keys[0] == "Provider", f"{provider_name}: col 1 should be Provider"
            assert (
                keys[1] == "Description"
            ), f"{provider_name}: col 2 should be Description"
            assert (
                keys[2] == expected_col3
            ), f"{provider_name}: col 3 should be {expected_col3}, got {keys[2]}"
            assert (
                keys[3] == expected_col4
            ), f"{provider_name}: col 4 should be {expected_col4}, got {keys[3]}"
            assert (
                keys[4] == "AssessmentDate"
            ), f"{provider_name}: col 5 should be AssessmentDate"