Compare commits

...

8 Commits

Author SHA1 Message Date
dependabot[bot] 5e488e2ee6 chore(deps): bump https://github.com/astral-sh/ruff-pre-commit
Bumps [https://github.com/astral-sh/ruff-pre-commit](https://github.com/astral-sh/ruff-pre-commit) from v0.15.11 to 0.15.12.
- [Release notes](https://github.com/astral-sh/ruff-pre-commit/releases)
- [Commits](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.11...v0.15.12)

---
updated-dependencies:
- dependency-name: https://github.com/astral-sh/ruff-pre-commit
  dependency-version: 0.15.12
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-02 01:46:08 +00:00
Hugo Pereira Brito 8db3a89669 ci: remove andoniaf from prowler-cloud (#10926) 2026-04-30 18:07:25 +02:00
Danny Lyubenov c802dc8a36 feat(codebuild): use batched API calls to prevent throttling and false positives (#10639)
Co-authored-by: Daniel Barranquero <danielbo2001@gmail.com>
2026-04-30 17:19:21 +02:00
Pedro Martín 3ab9a4efa5 chore(changelog): update with latest changes (#10948) 2026-04-30 14:13:40 +02:00
Pepe Fagoaga 36b8aa1b79 fix(boto3): pass config to clients (#10944) 2026-04-30 14:11:29 +02:00
Pedro Martín e821e07d7d docs(rbac): add Manage Alerts permission (#10947) 2026-04-30 13:58:17 +02:00
Boon 228fe6d579 feat: add ASD Essential Eight compliance framework for AWS (#10808)
Co-authored-by: Boon <boon@security8.work>
Co-authored-by: pedrooot <pedromarting3@gmail.com>
2026-04-30 13:49:08 +02:00
Pedro Martín 578186aa40 feat(sdk): integrate universal compliance into CLI pipeline (#10301) 2026-04-30 13:49:00 +02:00
47 changed files with 4761 additions and 201 deletions
+1 -1
View File
@@ -62,7 +62,7 @@ jobs:
"Alan-TheGentleman"
"alejandrobailo"
"amitsharm"
"andoniaf"
# "andoniaf"
"cesararroba"
"danibarranqueroo"
"HugoPBrito"
+1 -1
View File
@@ -98,7 +98,7 @@ repos:
## PYTHON — API + MCP Server (ruff)
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.11
rev: v0.15.12
hooks:
- id: ruff
name: "API + MCP - ruff check"
@@ -227,6 +227,7 @@ Assign administrative permissions by selecting from the following options:
| Manage Integrations | All | Add or modify the Prowler Integrations. |
| Manage Ingestions | Prowler Cloud | Allow or deny the ability to submit findings ingestion batches via the API. |
| Manage Billing | Prowler Cloud | Access and manage billing settings and subscription information. |
| Manage Alerts | Prowler Cloud | Create, edit, and delete alert rules and recipients. |
<Note>
The **Scope** column indicates where each permission applies. **All** means the permission is available in both Prowler Cloud and Self-Managed deployments. **Prowler Cloud** indicates permissions that are specific to [Prowler Cloud](https://cloud.prowler.com/sign-in).
@@ -241,3 +242,5 @@ The following permissions are available exclusively in **Prowler Cloud**:
**Manage Ingestions:** Submit and manage findings ingestion jobs via the API. Required to upload OCSF scan results using the `--push-to-cloud` CLI flag or the ingestion endpoints. See [Import Findings](/user-guide/tutorials/prowler-app-import-findings) for details.
**Manage Billing:** Access and manage billing settings, subscription plans, and payment methods.
**Manage Alerts:** Create, edit, and delete alert rules and recipients used to deliver scan-result digests via email.
+5
View File
@@ -7,16 +7,21 @@ All notable changes to the **Prowler SDK** are documented in this file.
### 🚀 Added
- `bedrock_guardrails_configured` check for AWS provider [(#10844)](https://github.com/prowler-cloud/prowler/pull/10844)
- Universal compliance pipeline integrated into the CLI: `--list-compliance` and `--list-compliance-requirements` show universal frameworks, and CSV plus OCSF outputs are generated for any framework declaring a `TableConfig` [(#10301)](https://github.com/prowler-cloud/prowler/pull/10301)
- ASD Essential Eight Maturity Model compliance framework for AWS (Maturity Level One, Nov 2023) [(#10808)](https://github.com/prowler-cloud/prowler/pull/10808)
### 🔄 Changed
- `route53_dangling_ip_subdomain_takeover` now also flags `CNAME` records pointing to S3 website endpoints whose buckets are missing from the account [(#10920)](https://github.com/prowler-cloud/prowler/pull/10920)
- Azure Network Watcher flow log checks now require workspace-backed Traffic Analytics for `network_flow_log_captured_sent` and align metadata with VNet-compatible flow log guidance [(#10645)](https://github.com/prowler-cloud/prowler/pull/10645)
- Azure compliance entries for legacy Network Watcher flow log controls now use retirement-aware guidance and point new deployments to VNet flow logs
- AWS CodeBuild service now batches `BatchGetProjects` and `BatchGetBuilds` calls per region (up to 100 items per call) to reduce API call volume and prevent throttling-induced false positives in `codebuild_project_not_publicly_accessible` [(#10639)](https://github.com/prowler-cloud/prowler/pull/10639)
- `display_compliance_table` dispatch switched from substring `in` checks to `startswith` to prevent false matches between similarly named frameworks (e.g. `cisa` vs `cis`) [(#10301)](https://github.com/prowler-cloud/prowler/pull/10301)
### 🐞 Fixed
- AWS SDK test isolation: autouse `mock_aws` fixture and leak detector in `conftest.py` to prevent tests from hitting real AWS endpoints, with idempotent organization setup for tests calling `set_mocked_aws_provider` multiple times [(#10605)](https://github.com/prowler-cloud/prowler/pull/10605)
- AWS `boto` user agent extra is now applied to every client [(#10944)](https://github.com/prowler-cloud/prowler/pull/10944)
### 🔐 Security
+56 -13
View File
@@ -45,7 +45,10 @@ from prowler.lib.check.check import (
)
from prowler.lib.check.checks_loader import load_checks_to_execute
from prowler.lib.check.compliance import update_checks_metadata_with_compliance
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.check.compliance_models import (
Compliance,
get_bulk_compliance_frameworks_universal,
)
from prowler.lib.check.custom_checks_metadata import (
parse_custom_checks_metadata_file,
update_checks_metadata,
@@ -75,7 +78,10 @@ from prowler.lib.outputs.compliance.cis.cis_oraclecloud import OracleCloudCIS
from prowler.lib.outputs.compliance.cisa_scuba.cisa_scuba_googleworkspace import (
GoogleWorkspaceCISASCuBA,
)
from prowler.lib.outputs.compliance.compliance import display_compliance_table
from prowler.lib.outputs.compliance.compliance import (
display_compliance_table,
process_universal_compliance_frameworks,
)
from prowler.lib.outputs.compliance.csa.csa_alibabacloud import AlibabaCloudCSA
from prowler.lib.outputs.compliance.csa.csa_aws import AWSCSA
from prowler.lib.outputs.compliance.csa.csa_azure import AzureCSA
@@ -84,6 +90,9 @@ from prowler.lib.outputs.compliance.csa.csa_oraclecloud import OracleCloudCSA
from prowler.lib.outputs.compliance.ens.ens_aws import AWSENS
from prowler.lib.outputs.compliance.ens.ens_azure import AzureENS
from prowler.lib.outputs.compliance.ens.ens_gcp import GCPENS
from prowler.lib.outputs.compliance.essential_eight.essential_eight_aws import (
EssentialEightAWS,
)
from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
from prowler.lib.outputs.compliance.iso27001.iso27001_aws import AWSISO27001
from prowler.lib.outputs.compliance.iso27001.iso27001_azure import AzureISO27001
@@ -235,6 +244,8 @@ def prowler():
# Load compliance frameworks
logger.debug("Loading compliance frameworks from .json files")
universal_frameworks = {}
# Skip compliance frameworks for external-tool providers
if provider not in EXTERNAL_TOOL_PROVIDERS:
bulk_compliance_frameworks = Compliance.get_bulk(provider)
@@ -242,6 +253,8 @@ def prowler():
bulk_checks_metadata = update_checks_metadata_with_compliance(
bulk_compliance_frameworks, bulk_checks_metadata
)
# Load universal compliance frameworks for new rendering pipeline
universal_frameworks = get_bulk_compliance_frameworks_universal(provider)
# Update checks metadata if the --custom-checks-metadata-file is present
custom_checks_metadata = None
@@ -254,12 +267,12 @@ def prowler():
)
if args.list_compliance:
print_compliance_frameworks(bulk_compliance_frameworks)
all_frameworks = {**bulk_compliance_frameworks, **universal_frameworks}
print_compliance_frameworks(all_frameworks)
sys.exit()
if args.list_compliance_requirements:
print_compliance_requirements(
bulk_compliance_frameworks, args.list_compliance_requirements
)
all_frameworks = {**bulk_compliance_frameworks, **universal_frameworks}
print_compliance_requirements(all_frameworks, args.list_compliance_requirements)
sys.exit()
# Load checks to execute
@@ -276,6 +289,7 @@ def prowler():
provider=provider,
list_checks=getattr(args, "list_checks", False)
or getattr(args, "list_checks_json", False),
universal_frameworks=universal_frameworks,
)
# if --list-checks-json, dump a json file and exit
@@ -624,15 +638,29 @@ def prowler():
)
# Compliance Frameworks
# Source the framework listing from `bulk_compliance_frameworks.keys()`
# so it is by construction a subset of what the bulk loader can resolve.
# `get_available_compliance_frameworks(provider)` also discovers top-level
# multi-provider universal JSONs (e.g. `prowler/compliance/csa_ccm_4.0.json`)
# which `Compliance.get_bulk(provider)` does not load, and which the legacy
# output handlers below cannot consume — using it as the source produced
# Source the framework listing from the union of `bulk_compliance_frameworks`
# and `universal_frameworks` so universal-only frameworks (e.g.
# `prowler/compliance/csa_ccm_4.0.json`) — which `Compliance.get_bulk(provider)`
# does not load — still reach `process_universal_compliance_frameworks` below.
# The provider-specific block subtracts the names handled by the universal
# processor so the legacy per-provider handlers only see frameworks that the
# bulk loader actually resolved.
input_compliance_frameworks = set(output_options.output_modes).intersection(
bulk_compliance_frameworks.keys()
set(bulk_compliance_frameworks.keys()) | set(universal_frameworks.keys())
)
# ── Universal compliance frameworks (provider-agnostic) ──
universal_processed = process_universal_compliance_frameworks(
input_compliance_frameworks=input_compliance_frameworks,
universal_frameworks=universal_frameworks,
finding_outputs=finding_outputs,
output_directory=output_options.output_directory,
output_filename=output_options.output_filename,
provider=provider,
generated_outputs=generated_outputs,
)
input_compliance_frameworks -= universal_processed
if provider == "aws":
for compliance_name in input_compliance_frameworks:
if compliance_name.startswith("cis_"):
@@ -648,6 +676,18 @@ def prowler():
)
generated_outputs["compliance"].append(cis)
cis.batch_write_data_to_file()
elif compliance_name.startswith("essential_eight"):
filename = (
f"{output_options.output_directory}/compliance/"
f"{output_options.output_filename}_{compliance_name}.csv"
)
essential_eight = EssentialEightAWS(
findings=finding_outputs,
compliance=bulk_compliance_frameworks[compliance_name],
file_path=filename,
)
generated_outputs["compliance"].append(essential_eight)
essential_eight.batch_write_data_to_file()
elif compliance_name == "mitre_attack_aws":
# Generate MITRE ATT&CK Finding Object
filename = (
@@ -1402,6 +1442,9 @@ def prowler():
output_options.output_filename,
output_options.output_directory,
compliance_overview,
universal_frameworks=universal_frameworks,
provider=provider,
output_formats=args.output_formats,
)
if compliance_overview:
print(
File diff suppressed because it is too large Load Diff
+5 -3
View File
@@ -87,8 +87,8 @@ def get_available_compliance_frameworks(provider=None):
providers = [p.value for p in Provider]
if provider:
providers = [provider]
for provider in providers:
compliance_dir = f"{actual_directory}/../compliance/{provider}"
for current_provider in providers:
compliance_dir = f"{actual_directory}/../compliance/{current_provider}"
if not os.path.isdir(compliance_dir):
continue
with os.scandir(compliance_dir) as files:
@@ -97,7 +97,9 @@ def get_available_compliance_frameworks(provider=None):
available_compliance_frameworks.append(
file.name.removesuffix(".json")
)
# Also scan top-level compliance/ for multi-provider JSONs
# Also scan top-level compliance/ for multi-provider (universal) JSONs.
# When a specific provider was requested, only include the framework if it
# declares support for that provider; otherwise include all universal frameworks.
compliance_root = f"{actual_directory}/../compliance"
if os.path.isdir(compliance_root):
with os.scandir(compliance_root) as files:
+30 -7
View File
@@ -299,12 +299,22 @@ def print_compliance_frameworks(
def print_compliance_requirements(
bulk_compliance_frameworks: dict, compliance_frameworks: list
):
from prowler.lib.check.compliance_models import ComplianceFramework
for compliance_framework in compliance_frameworks:
for key in bulk_compliance_frameworks.keys():
framework = bulk_compliance_frameworks[key].Framework
provider = bulk_compliance_frameworks[key].Provider
version = bulk_compliance_frameworks[key].Version
requirements = bulk_compliance_frameworks[key].Requirements
entry = bulk_compliance_frameworks[key]
is_universal = isinstance(entry, ComplianceFramework)
if is_universal:
framework = entry.framework
provider = entry.provider or "Multi-provider"
version = entry.version
requirements = entry.requirements
else:
framework = entry.Framework
provider = entry.Provider or "Multi-provider"
version = entry.Version
requirements = entry.Requirements
# We can list the compliance requirements for a given framework using the
# bulk_compliance_frameworks keys since they are the compliance specification file name
if compliance_framework == key:
@@ -313,10 +323,23 @@ def print_compliance_requirements(
)
for requirement in requirements:
checks = ""
for check in requirement.Checks:
checks += f" {Fore.YELLOW}\t\t{check}\n{Style.RESET_ALL}"
if is_universal:
req_checks = requirement.checks
req_id = requirement.id
req_description = requirement.description
else:
req_checks = requirement.Checks
req_id = requirement.Id
req_description = requirement.Description
if isinstance(req_checks, dict):
for prov, check_list in req_checks.items():
for check in check_list:
checks += f" {Fore.YELLOW}\t\t[{prov}] {check}\n{Style.RESET_ALL}"
else:
for check in req_checks:
checks += f" {Fore.YELLOW}\t\t{check}\n{Style.RESET_ALL}"
print(
f"Requirement Id: {Fore.MAGENTA}{requirement.Id}{Style.RESET_ALL}\n\t- Description: {requirement.Description}\n\t- Checks:\n{checks}"
f"Requirement Id: {Fore.MAGENTA}{req_id}{Style.RESET_ALL}\n\t- Description: {req_description}\n\t- Checks:\n{checks}"
)
+15 -5
View File
@@ -22,6 +22,7 @@ def load_checks_to_execute(
categories: set = None,
resource_groups: set = None,
list_checks: bool = False,
universal_frameworks: dict = None,
) -> set:
"""Generate the list of checks to execute based on the cloud provider and the input arguments given"""
try:
@@ -155,12 +156,21 @@ def load_checks_to_execute(
if not bulk_compliance_frameworks:
bulk_compliance_frameworks = Compliance.get_bulk(provider=provider)
for compliance_framework in compliance_frameworks:
checks_to_execute.update(
CheckMetadata.list(
bulk_compliance_frameworks=bulk_compliance_frameworks,
compliance_framework=compliance_framework,
# Try universal frameworks first (snake_case dict-keyed checks)
if (
universal_frameworks
and compliance_framework in universal_frameworks
):
fw = universal_frameworks[compliance_framework]
for req in fw.requirements:
checks_to_execute.update(req.checks.get(provider.lower(), []))
elif compliance_framework in bulk_compliance_frameworks:
checks_to_execute.update(
CheckMetadata.list(
bulk_compliance_frameworks=bulk_compliance_frameworks,
compliance_framework=compliance_framework,
)
)
)
# Handle if there are categories passed using --categories
elif categories:
+43
View File
@@ -102,6 +102,48 @@ class CIS_Requirement_Attribute(BaseModel):
References: str
class EssentialEight_Requirement_Attribute_MaturityLevel(str, Enum):
"""ASD Essential Eight Maturity Level"""
ML1 = "ML1"
ML2 = "ML2"
ML3 = "ML3"
class EssentialEight_Requirement_Attribute_AssessmentStatus(str, Enum):
"""Essential Eight Requirement Attribute Assessment Status"""
Manual = "Manual"
Automated = "Automated"
class EssentialEight_Requirement_Attribute_CloudApplicability(str, Enum):
"""How well the ASD control maps to AWS cloud infrastructure."""
Full = "full"
Partial = "partial"
Limited = "limited"
NonApplicable = "non-applicable"
# Essential Eight Requirement Attribute
class EssentialEight_Requirement_Attribute(BaseModel):
"""ASD Essential Eight Requirement Attribute"""
Section: str
MaturityLevel: EssentialEight_Requirement_Attribute_MaturityLevel
AssessmentStatus: EssentialEight_Requirement_Attribute_AssessmentStatus
CloudApplicability: EssentialEight_Requirement_Attribute_CloudApplicability
MitigatedThreats: list[str]
Description: str
RationaleStatement: str
ImpactStatement: str
RemediationProcedure: str
AuditProcedure: str
AdditionalInformation: str
References: str
# Well Architected Requirement Attribute
class AWS_Well_Architected_Requirement_Attribute(BaseModel):
"""AWS Well Architected Requirement Attribute"""
@@ -250,6 +292,7 @@ class Compliance_Requirement(BaseModel):
Name: Optional[str] = None
Attributes: list[
Union[
EssentialEight_Requirement_Attribute,
CIS_Requirement_Attribute,
ENS_Requirement_Attribute,
ISO27001_2013_Requirement_Attribute,
+141 -62
View File
@@ -1,12 +1,17 @@
import sys
from prowler.lib.check.models import Check_Report
from prowler.lib.logger import logger
from prowler.lib.outputs.compliance.c5.c5 import get_c5_table
from prowler.lib.outputs.compliance.ccc.ccc import get_ccc_table
from prowler.lib.outputs.compliance.cis.cis import get_cis_table
from prowler.lib.outputs.compliance.compliance_check import ( # noqa: F401 - re-export for backward compatibility
get_check_compliance,
)
from prowler.lib.outputs.compliance.csa.csa import get_csa_table
from prowler.lib.outputs.compliance.ens.ens import get_ens_table
from prowler.lib.outputs.compliance.essential_eight.essential_eight import (
get_essential_eight_table,
)
from prowler.lib.outputs.compliance.generic.generic_table import (
get_generic_compliance_table,
)
@@ -17,6 +22,94 @@ from prowler.lib.outputs.compliance.mitre_attack.mitre_attack import (
from prowler.lib.outputs.compliance.prowler_threatscore.prowler_threatscore import (
get_prowler_threatscore_table,
)
from prowler.lib.outputs.compliance.universal.universal_table import get_universal_table
def process_universal_compliance_frameworks(
input_compliance_frameworks: set,
universal_frameworks: dict,
finding_outputs: list,
output_directory: str,
output_filename: str,
provider: str,
generated_outputs: dict,
) -> set:
"""Process universal compliance frameworks, generating CSV and OCSF outputs.
For each framework in *input_compliance_frameworks* that exists in
*universal_frameworks* and has an outputs.table_config, this function
creates both a CSV (UniversalComplianceOutput) and an OCSF JSON
(OCSFComplianceOutput) file. OCSF is always generated regardless of
the user's ``--output-formats`` flag.
The function is idempotent: it tracks already-created writers via
``generated_outputs["compliance"]`` keyed by ``file_path``. If invoked
again for the same framework (e.g. once per streaming batch), it
reuses the existing writer instead of recreating it. This guarantees
one output writer per framework for the whole execution and keeps
the OCSF JSON array valid across multiple calls.
Returns the set of framework names that were processed so the caller
can remove them before entering the legacy per-provider output loop.
"""
from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
OCSFComplianceOutput,
)
from prowler.lib.outputs.compliance.universal.universal_output import (
UniversalComplianceOutput,
)
existing_writers = {
getattr(out, "file_path", None): out
for out in generated_outputs.get("compliance", [])
if isinstance(out, (UniversalComplianceOutput, OCSFComplianceOutput))
}
processed = set()
for compliance_name in input_compliance_frameworks:
if not (
compliance_name in universal_frameworks
and universal_frameworks[compliance_name].outputs
and universal_frameworks[compliance_name].outputs.table_config
):
continue
fw = universal_frameworks[compliance_name]
# CSV output
csv_path = (
f"{output_directory}/compliance/" f"{output_filename}_{compliance_name}.csv"
)
if csv_path not in existing_writers:
output = UniversalComplianceOutput(
findings=finding_outputs,
framework=fw,
file_path=csv_path,
provider=provider,
)
generated_outputs["compliance"].append(output)
existing_writers[csv_path] = output
output.batch_write_data_to_file()
# OCSF output (always generated for universal frameworks)
ocsf_path = (
f"{output_directory}/compliance/"
f"{output_filename}_{compliance_name}.ocsf.json"
)
if ocsf_path not in existing_writers:
ocsf_output = OCSFComplianceOutput(
findings=finding_outputs,
framework=fw,
file_path=ocsf_path,
provider=provider,
)
generated_outputs["compliance"].append(ocsf_output)
existing_writers[ocsf_path] = ocsf_output
ocsf_output.batch_write_data_to_file()
processed.add(compliance_name)
return processed
def display_compliance_table(
@@ -26,6 +119,9 @@ def display_compliance_table(
output_filename: str,
output_directory: str,
compliance_overview: bool,
universal_frameworks: dict = None,
provider: str = None,
output_formats: list = None,
) -> None:
"""
display_compliance_table generates the compliance table for the given compliance framework.
@@ -37,6 +133,9 @@ def display_compliance_table(
output_filename (str): The output filename
output_directory (str): The output directory
compliance_overview (bool): The compliance
universal_frameworks (dict): Optional universal ComplianceFramework objects
provider (str): The current provider (e.g. "aws") for multi-provider filtering
output_formats (list): The output formats to generate
Returns:
None
@@ -45,16 +144,24 @@ def display_compliance_table(
findings = [f for f in findings if f.check_metadata.CheckID in bulk_checks_metadata]
try:
if "ens_" in compliance_framework:
get_ens_table(
findings,
bulk_checks_metadata,
compliance_framework,
output_filename,
output_directory,
compliance_overview,
)
elif "cis_" in compliance_framework:
# Universal path: if the framework has TableConfig, use the universal renderer
if universal_frameworks and compliance_framework in universal_frameworks:
fw = universal_frameworks[compliance_framework]
if fw.outputs and fw.outputs.table_config:
get_universal_table(
findings,
bulk_checks_metadata,
compliance_framework,
output_filename,
output_directory,
compliance_overview,
framework=fw,
provider=provider,
output_formats=output_formats,
)
return
if compliance_framework.startswith("cis_"):
get_cis_table(
findings,
bulk_checks_metadata,
@@ -63,7 +170,16 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "mitre_attack" in compliance_framework:
elif compliance_framework.startswith("ens_"):
get_ens_table(
findings,
bulk_checks_metadata,
compliance_framework,
output_filename,
output_directory,
compliance_overview,
)
elif compliance_framework.startswith("mitre_attack"):
get_mitre_attack_table(
findings,
bulk_checks_metadata,
@@ -72,7 +188,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "kisa_isms_" in compliance_framework:
elif compliance_framework.startswith("kisa"):
get_kisa_ismsp_table(
findings,
bulk_checks_metadata,
@@ -81,7 +197,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "threatscore_" in compliance_framework:
elif compliance_framework.startswith("prowler_threatscore_"):
get_prowler_threatscore_table(
findings,
bulk_checks_metadata,
@@ -90,7 +206,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "csa_ccm_" in compliance_framework:
elif compliance_framework.startswith("csa_ccm_"):
get_csa_table(
findings,
bulk_checks_metadata,
@@ -99,7 +215,7 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "c5_" in compliance_framework:
elif compliance_framework.startswith("c5_"):
get_c5_table(
findings,
bulk_checks_metadata,
@@ -117,6 +233,15 @@ def display_compliance_table(
output_directory,
compliance_overview,
)
elif "essential_eight" in compliance_framework:
get_essential_eight_table(
findings,
bulk_checks_metadata,
compliance_framework,
output_filename,
output_directory,
compliance_overview,
)
else:
get_generic_compliance_table(
findings,
@@ -131,49 +256,3 @@ def display_compliance_table(
f"{error.__class__.__name__}:{error.__traceback__.tb_lineno} -- {error}"
)
sys.exit(1)
# TODO: this should be in the Check class
def get_check_compliance(
finding: Check_Report, provider_type: str, bulk_checks_metadata: dict
) -> dict:
"""get_check_compliance returns a map with the compliance framework as key and the requirements where the finding's check is present.
Example:
{
"CIS-1.4": ["2.1.3"],
"CIS-1.5": ["2.1.3"],
}
Args:
finding (Any): The Check_Report finding
provider_type (str): The provider type
bulk_checks_metadata (dict): The bulk checks metadata
Returns:
dict: The compliance framework as key and the requirements where the finding's check is present.
"""
try:
check_compliance = {}
# We have to retrieve all the check's compliance requirements
if finding.check_metadata.CheckID in bulk_checks_metadata:
for compliance in bulk_checks_metadata[
finding.check_metadata.CheckID
].Compliance:
compliance_fw = compliance.Framework
if compliance.Version:
compliance_fw = f"{compliance_fw}-{compliance.Version}"
# compliance.Provider == "Azure" or "Kubernetes"
# provider_type == "azure" or "kubernetes"
if compliance.Provider.upper() == provider_type.upper():
if compliance_fw not in check_compliance:
check_compliance[compliance_fw] = []
for requirement in compliance.Requirements:
check_compliance[compliance_fw].append(requirement.Id)
return check_compliance
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
return {}
@@ -0,0 +1,48 @@
from prowler.lib.check.models import Check_Report
from prowler.lib.logger import logger
# TODO: this should be in the Check class
def get_check_compliance(
finding: Check_Report, provider_type: str, bulk_checks_metadata: dict
) -> dict:
"""get_check_compliance returns a map with the compliance framework as key and the requirements where the finding's check is present.
Example:
{
"CIS-1.4": ["2.1.3"],
"CIS-1.5": ["2.1.3"],
}
Args:
finding (Any): The Check_Report finding
provider_type (str): The provider type
bulk_checks_metadata (dict): The bulk checks metadata
Returns:
dict: The compliance framework as key and the requirements where the finding's check is present.
"""
try:
check_compliance = {}
# We have to retrieve all the check's compliance requirements
if finding.check_metadata.CheckID in bulk_checks_metadata:
for compliance in bulk_checks_metadata[
finding.check_metadata.CheckID
].Compliance:
compliance_fw = compliance.Framework
if compliance.Version:
compliance_fw = f"{compliance_fw}-{compliance.Version}"
# compliance.Provider == "Azure" or "Kubernetes"
# provider_type == "azure" or "kubernetes"
if compliance.Provider.upper() == provider_type.upper():
if compliance_fw not in check_compliance:
check_compliance[compliance_fw] = []
for requirement in compliance.Requirements:
check_compliance[compliance_fw].append(requirement.Id)
return check_compliance
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
)
return {}
@@ -0,0 +1,98 @@
from colorama import Fore, Style
from tabulate import tabulate
from prowler.config.config import orange_color
def get_essential_eight_table(
findings: list,
bulk_checks_metadata: dict,
compliance_framework: str,
output_filename: str,
output_directory: str,
compliance_overview: bool,
):
sections = {}
essential_eight_compliance_table = {
"Provider": [],
"Section": [],
"Status": [],
"Muted": [],
}
pass_count = []
fail_count = []
muted_count = []
for index, finding in enumerate(findings):
check = bulk_checks_metadata[finding.check_metadata.CheckID]
check_compliances = check.Compliance
for compliance in check_compliances:
if compliance.Framework == "Essential-Eight":
for requirement in compliance.Requirements:
for attribute in requirement.Attributes:
section = attribute.Section
if section not in sections:
sections[section] = {
"FAIL": 0,
"PASS": 0,
"Muted": 0,
}
if finding.muted:
if index not in muted_count:
muted_count.append(index)
sections[section]["Muted"] += 1
else:
if finding.status == "FAIL" and index not in fail_count:
fail_count.append(index)
sections[section]["FAIL"] += 1
elif finding.status == "PASS" and index not in pass_count:
pass_count.append(index)
sections[section]["PASS"] += 1
sections = dict(sorted(sections.items()))
for section in sections:
essential_eight_compliance_table["Provider"].append(compliance.Provider)
essential_eight_compliance_table["Section"].append(section)
if sections[section]["FAIL"] > 0:
essential_eight_compliance_table["Status"].append(
f"{Fore.RED}FAIL({sections[section]['FAIL']}){Style.RESET_ALL}"
)
elif sections[section]["PASS"] > 0:
essential_eight_compliance_table["Status"].append(
f"{Fore.GREEN}PASS({sections[section]['PASS']}){Style.RESET_ALL}"
)
else:
essential_eight_compliance_table["Status"].append("-")
essential_eight_compliance_table["Muted"].append(
f"{orange_color}{sections[section]['Muted']}{Style.RESET_ALL}"
)
if len(fail_count) + len(pass_count) + len(muted_count) > 1:
print(
f"\nCompliance Status of {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Framework:"
)
total_findings_count = len(fail_count) + len(pass_count) + len(muted_count)
overview_table = [
[
f"{Fore.RED}{round(len(fail_count) / total_findings_count * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
f"{Fore.GREEN}{round(len(pass_count) / total_findings_count * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
f"{orange_color}{round(len(muted_count) / total_findings_count * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
]
]
print(tabulate(overview_table, tablefmt="rounded_grid"))
if not compliance_overview:
print(
f"\nFramework {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Results:"
)
print(
tabulate(
essential_eight_compliance_table,
headers="keys",
tablefmt="rounded_grid",
)
)
print(
f"{Style.BRIGHT}* Only sections containing results appear.{Style.RESET_ALL}"
)
print(f"\nDetailed results of {compliance_framework.upper()} are in:")
print(
f" - CSV: {output_directory}/compliance/{output_filename}_{compliance_framework}.csv\n"
)
@@ -0,0 +1,111 @@
from prowler.config.config import timestamp
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.compliance_output import ComplianceOutput
from prowler.lib.outputs.compliance.essential_eight.models import (
EssentialEightAWSModel,
)
from prowler.lib.outputs.finding import Finding
class EssentialEightAWS(ComplianceOutput):
"""
This class represents the AWS ASD Essential Eight compliance output.
Attributes:
- _data (list): A list to store transformed data from findings.
- _file_descriptor (TextIOWrapper): A file descriptor to write data to a file.
Methods:
- transform: Transforms findings into AWS Essential Eight compliance format.
"""
def transform(
self,
findings: list[Finding],
compliance: Compliance,
compliance_name: str,
) -> None:
"""
Transforms a list of findings into AWS Essential Eight compliance format.
Parameters:
- findings (list): A list of findings.
- compliance (Compliance): A compliance model.
- compliance_name (str): The name of the compliance model.
Returns:
- None
"""
for finding in findings:
finding_requirements = finding.compliance.get(compliance_name, [])
for requirement in compliance.Requirements:
if requirement.Id in finding_requirements:
for attribute in requirement.Attributes:
compliance_row = EssentialEightAWSModel(
Provider=finding.provider,
Description=compliance.Description,
AccountId=finding.account_uid,
Region=finding.region,
AssessmentDate=str(timestamp),
Requirements_Id=requirement.Id,
Requirements_Description=requirement.Description,
Requirements_Attributes_Section=attribute.Section,
Requirements_Attributes_MaturityLevel=attribute.MaturityLevel,
Requirements_Attributes_AssessmentStatus=attribute.AssessmentStatus,
Requirements_Attributes_CloudApplicability=attribute.CloudApplicability,
Requirements_Attributes_MitigatedThreats=", ".join(
attribute.MitigatedThreats
),
Requirements_Attributes_Description=attribute.Description,
Requirements_Attributes_RationaleStatement=attribute.RationaleStatement,
Requirements_Attributes_ImpactStatement=attribute.ImpactStatement,
Requirements_Attributes_RemediationProcedure=attribute.RemediationProcedure,
Requirements_Attributes_AuditProcedure=attribute.AuditProcedure,
Requirements_Attributes_AdditionalInformation=attribute.AdditionalInformation,
Requirements_Attributes_References=attribute.References,
Status=finding.status,
StatusExtended=finding.status_extended,
ResourceId=finding.resource_uid,
ResourceName=finding.resource_name,
CheckId=finding.check_id,
Muted=finding.muted,
Framework=compliance.Framework,
Name=compliance.Name,
)
self._data.append(compliance_row)
# Add manual requirements to the compliance output
for requirement in compliance.Requirements:
if not requirement.Checks:
for attribute in requirement.Attributes:
compliance_row = EssentialEightAWSModel(
Provider=compliance.Provider.lower(),
Description=compliance.Description,
AccountId="",
Region="",
AssessmentDate=str(timestamp),
Requirements_Id=requirement.Id,
Requirements_Description=requirement.Description,
Requirements_Attributes_Section=attribute.Section,
Requirements_Attributes_MaturityLevel=attribute.MaturityLevel,
Requirements_Attributes_AssessmentStatus=attribute.AssessmentStatus,
Requirements_Attributes_CloudApplicability=attribute.CloudApplicability,
Requirements_Attributes_MitigatedThreats=", ".join(
attribute.MitigatedThreats
),
Requirements_Attributes_Description=attribute.Description,
Requirements_Attributes_RationaleStatement=attribute.RationaleStatement,
Requirements_Attributes_ImpactStatement=attribute.ImpactStatement,
Requirements_Attributes_RemediationProcedure=attribute.RemediationProcedure,
Requirements_Attributes_AuditProcedure=attribute.AuditProcedure,
Requirements_Attributes_AdditionalInformation=attribute.AdditionalInformation,
Requirements_Attributes_References=attribute.References,
Status="MANUAL",
StatusExtended="Manual check",
ResourceId="manual_check",
ResourceName="Manual check",
CheckId="manual",
Muted=False,
Framework=compliance.Framework,
Name=compliance.Name,
)
self._data.append(compliance_row)
@@ -0,0 +1,35 @@
from pydantic.v1 import BaseModel
class EssentialEightAWSModel(BaseModel):
"""
EssentialEightAWSModel generates a finding's output in AWS ASD Essential Eight Compliance format.
"""
Provider: str
Description: str
AccountId: str
Region: str
AssessmentDate: str
Requirements_Id: str
Requirements_Description: str
Requirements_Attributes_Section: str
Requirements_Attributes_MaturityLevel: str
Requirements_Attributes_AssessmentStatus: str
Requirements_Attributes_CloudApplicability: str
Requirements_Attributes_MitigatedThreats: str
Requirements_Attributes_Description: str
Requirements_Attributes_RationaleStatement: str
Requirements_Attributes_ImpactStatement: str
Requirements_Attributes_RemediationProcedure: str
Requirements_Attributes_AuditProcedure: str
Requirements_Attributes_AdditionalInformation: str
Requirements_Attributes_References: str
Status: str
StatusExtended: str
ResourceId: str
ResourceName: str
CheckId: str
Muted: bool
Framework: str
Name: str
@@ -1,6 +1,7 @@
import json
import os
from datetime import datetime
from typing import List
from typing import TYPE_CHECKING, List
from py_ocsf_models.events.base_event import SeverityID
from py_ocsf_models.events.base_event import StatusID as EventStatusID
@@ -20,11 +21,12 @@ from py_ocsf_models.objects.resource_details import ResourceDetails
from prowler.config.config import prowler_version
from prowler.lib.check.compliance_models import ComplianceFramework
from prowler.lib.logger import logger
from prowler.lib.outputs.finding import Finding
from prowler.lib.outputs.ocsf.ocsf import OCSF
from prowler.lib.outputs.utils import unroll_dict_to_list
from prowler.lib.utils.utils import open_file
if TYPE_CHECKING:
from prowler.lib.outputs.finding import Finding
PROWLER_TO_COMPLIANCE_STATUS = {
"PASS": ComplianceStatusID.Pass,
"FAIL": ComplianceStatusID.Fail,
@@ -32,6 +34,40 @@ PROWLER_TO_COMPLIANCE_STATUS = {
}
def _sanitize_resource_data(resource_details, resource_metadata) -> dict:
"""Ensure resource data is JSON-serializable.
Service resource_metadata may carry non-serializable objects (e.g. raw
Pydantic models or service classes such as ``Trail`` / ``LifecyclePolicy``).
Convert them to plain dicts and roundtrip through JSON so the resulting
ComplianceFinding can be serialized without errors.
"""
def _make_serializable(obj):
if hasattr(obj, "model_dump") and callable(obj.model_dump):
return _make_serializable(obj.model_dump())
if hasattr(obj, "dict") and callable(obj.dict):
return _make_serializable(obj.dict())
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_make_serializable(v) for v in obj]
return obj
try:
converted = _make_serializable(resource_metadata)
sanitized_metadata = json.loads(json.dumps(converted, default=str))
except (TypeError, ValueError, RecursionError) as error:
logger.warning(
f"Failed to serialize resource metadata, defaulting to empty: {error}"
)
sanitized_metadata = {}
return {
"details": resource_details,
"metadata": sanitized_metadata,
}
def _to_snake_case(name: str) -> str:
"""Convert a PascalCase or camelCase string to snake_case."""
import re
@@ -108,7 +144,7 @@ class OCSFComplianceOutput:
def _transform(
self,
findings: List[Finding],
findings: List["Finding"],
framework: ComplianceFramework,
compliance_name: str,
) -> None:
@@ -177,7 +213,7 @@ class OCSFComplianceOutput:
def _build_compliance_finding(
self,
finding: Finding,
finding: "Finding",
framework: ComplianceFramework,
requirement,
compliance_name: str,
@@ -195,7 +231,9 @@ class OCSFComplianceOutput:
finding.metadata.Severity.capitalize(),
SeverityID.Unknown,
)
event_status = OCSF.get_finding_status_id(finding.muted)
event_status = (
EventStatusID.Suppressed if finding.muted else EventStatusID.New
)
time_value = (
int(finding.timestamp.timestamp())
@@ -268,10 +306,10 @@ class OCSFComplianceOutput:
if finding.provider == "kubernetes"
else None
),
data={
"details": finding.resource_details,
"metadata": finding.resource_metadata,
},
data=_sanitize_resource_data(
finding.resource_details,
finding.resource_metadata,
),
)
],
severity_id=finding_severity.value,
@@ -0,0 +1,294 @@
from csv import DictWriter
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from pydantic.v1 import create_model
from prowler.config.config import timestamp
from prowler.lib.check.compliance_models import ComplianceFramework
from prowler.lib.logger import logger
from prowler.lib.utils.utils import open_file
if TYPE_CHECKING:
from prowler.lib.outputs.finding import Finding
PROVIDER_HEADER_MAP = {
"aws": ("AccountId", "account_uid", "Region", "region"),
"azure": ("SubscriptionId", "account_uid", "Location", "region"),
"gcp": ("ProjectId", "account_uid", "Location", "region"),
"kubernetes": ("Context", "account_name", "Namespace", "region"),
"m365": ("TenantId", "account_uid", "Location", "region"),
"github": ("Account_Name", "account_name", "Account_Id", "account_uid"),
"oraclecloud": ("TenancyId", "account_uid", "Region", "region"),
"alibabacloud": ("AccountId", "account_uid", "Region", "region"),
"nhn": ("AccountId", "account_uid", "Region", "region"),
}
_DEFAULT_HEADERS = ("AccountId", "account_uid", "Region", "region")
class UniversalComplianceOutput:
"""Universal compliance CSV output driven by ComplianceFramework metadata.
Dynamically builds a Pydantic row model from attributes_metadata so that
CSV columns match the framework's declared attribute fields.
"""
def __init__(
self,
findings: list,
framework: ComplianceFramework,
file_path: str = None,
from_cli: bool = True,
provider: str = None,
) -> None:
self._data = []
self._file_descriptor = None
self.file_path = file_path
self._from_cli = from_cli
self._provider = provider
self.close_file = False
if file_path:
path_obj = Path(file_path)
self._file_extension = path_obj.suffix if path_obj.suffix else ""
if findings:
self._row_model = self._build_row_model(framework)
compliance_name = (
framework.framework + "-" + framework.version
if framework.version
else framework.framework
)
self._transform(findings, framework, compliance_name)
if not self._file_descriptor and file_path:
self._create_file_descriptor(file_path)
@property
def data(self):
return self._data
def _build_row_model(self, framework: ComplianceFramework):
"""Build a dynamic Pydantic model from attributes_metadata."""
acct_header, acct_field, loc_header, loc_field = PROVIDER_HEADER_MAP.get(
(self._provider or "").lower(), _DEFAULT_HEADERS
)
self._acct_header = acct_header
self._acct_field = acct_field
self._loc_header = loc_header
self._loc_field = loc_field
# Base fields present in every compliance CSV
fields = {
"Provider": (str, ...),
"Description": (str, ...),
acct_header: (str, ...),
loc_header: (str, ...),
"AssessmentDate": (str, ...),
"Requirements_Id": (str, ...),
"Requirements_Description": (str, ...),
}
# Dynamic attribute columns from metadata
if framework.attributes_metadata:
for attr_meta in framework.attributes_metadata:
if not attr_meta.output_formats.csv:
continue
field_name = f"Requirements_Attributes_{attr_meta.key}"
# Map type strings to Python types
type_map = {
"str": Optional[str],
"int": Optional[int],
"float": Optional[float],
"bool": Optional[bool],
"list_str": Optional[str], # Serialized as joined string
"list_dict": Optional[str], # Serialized as string
}
py_type = type_map.get(attr_meta.type, Optional[str])
fields[field_name] = (py_type, None)
# Check if any requirement has MITRE fields
has_mitre = any(req.tactics for req in framework.requirements if req.tactics)
if has_mitre:
fields["Requirements_Tactics"] = (Optional[str], None)
fields["Requirements_SubTechniques"] = (Optional[str], None)
fields["Requirements_Platforms"] = (Optional[str], None)
fields["Requirements_TechniqueURL"] = (Optional[str], None)
# Trailing fields
fields["Status"] = (str, ...)
fields["StatusExtended"] = (str, ...)
fields["ResourceId"] = (str, ...)
fields["ResourceName"] = (str, ...)
fields["CheckId"] = (str, ...)
fields["Muted"] = (bool, ...)
fields["Framework"] = (str, ...)
fields["Name"] = (str, ...)
return create_model("UniversalComplianceRow", **fields)
def _serialize_attr_value(self, value):
"""Serialize attribute values for CSV."""
if isinstance(value, list):
if value and isinstance(value[0], dict):
return str(value)
return " | ".join(str(v) for v in value)
return value
def _build_row(self, finding, framework, requirement, is_manual=False):
"""Build a single row dict for a finding + requirement combination."""
row = {
"Provider": (
finding.provider
if not is_manual
else (framework.provider or self._provider or "").lower()
),
"Description": framework.description,
self._acct_header: (
getattr(finding, self._acct_field, "") if not is_manual else ""
),
self._loc_header: (
getattr(finding, self._loc_field, "") if not is_manual else ""
),
"AssessmentDate": str(timestamp),
"Requirements_Id": requirement.id,
"Requirements_Description": requirement.description,
}
# Add dynamic attribute columns
if framework.attributes_metadata:
for attr_meta in framework.attributes_metadata:
if not attr_meta.output_formats.csv:
continue
field_name = f"Requirements_Attributes_{attr_meta.key}"
raw_val = requirement.attributes.get(attr_meta.key)
row[field_name] = (
self._serialize_attr_value(raw_val) if raw_val is not None else None
)
# MITRE fields
if requirement.tactics:
row["Requirements_Tactics"] = (
" | ".join(requirement.tactics) if requirement.tactics else None
)
row["Requirements_SubTechniques"] = (
" | ".join(requirement.sub_techniques)
if requirement.sub_techniques
else None
)
row["Requirements_Platforms"] = (
" | ".join(requirement.platforms) if requirement.platforms else None
)
row["Requirements_TechniqueURL"] = requirement.technique_url
row["Status"] = finding.status if not is_manual else "MANUAL"
row["StatusExtended"] = (
finding.status_extended if not is_manual else "Manual check"
)
row["ResourceId"] = finding.resource_uid if not is_manual else "manual_check"
row["ResourceName"] = finding.resource_name if not is_manual else "Manual check"
row["CheckId"] = finding.check_id if not is_manual else "manual"
row["Muted"] = finding.muted if not is_manual else False
row["Framework"] = framework.framework
row["Name"] = framework.name
return row
def _transform(
self,
findings: list["Finding"],
framework: ComplianceFramework,
compliance_name: str,
) -> None:
"""Transform findings into universal compliance CSV rows."""
# Build check -> requirements map (filtered by provider for dict checks)
check_req_map = {}
for req in framework.requirements:
checks = req.checks
if self._provider:
all_checks = checks.get(self._provider.lower(), [])
else:
all_checks = []
for check_list in checks.values():
all_checks.extend(check_list)
for check_id in all_checks:
if check_id not in check_req_map:
check_req_map[check_id] = []
check_req_map[check_id].append(req)
# Process findings using the provider-filtered check_req_map.
# This ensures that for multi-provider dict checks, only the checks
# belonging to the current provider produce output rows.
for finding in findings:
check_id = finding.check_id
if check_id in check_req_map:
for req in check_req_map[check_id]:
row = self._build_row(finding, framework, req)
try:
self._data.append(self._row_model(**row))
except Exception as e:
logger.debug(f"Skipping row for {req.id}: {e}")
# Manual requirements (no checks or empty dict)
for req in framework.requirements:
checks = req.checks
if self._provider:
has_checks = bool(checks.get(self._provider.lower(), []))
else:
has_checks = any(checks.values())
if not has_checks:
# Use a dummy finding-like namespace for manual rows
row = self._build_row(
_ManualFindingStub(), framework, req, is_manual=True
)
try:
self._data.append(self._row_model(**row))
except Exception as e:
logger.debug(f"Skipping manual row for {req.id}: {e}")
def _create_file_descriptor(self, file_path: str) -> None:
try:
self._file_descriptor = open_file(file_path, "a")
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def batch_write_data_to_file(self) -> None:
"""Write findings data to CSV."""
try:
if (
getattr(self, "_file_descriptor", None)
and not self._file_descriptor.closed
and self._data
):
csv_writer = DictWriter(
self._file_descriptor,
fieldnames=[field.upper() for field in self._data[0].dict().keys()],
delimiter=";",
)
if self._file_descriptor.tell() == 0:
csv_writer.writeheader()
for row in self._data:
csv_writer.writerow({k.upper(): v for k, v in row.dict().items()})
if self.close_file or self._from_cli:
self._file_descriptor.close()
except Exception as error:
logger.error(
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
class _ManualFindingStub:
"""Minimal stub to satisfy _build_row for manual requirements."""
provider = ""
account_uid = ""
account_name = ""
region = ""
status = "MANUAL"
status_extended = "Manual check"
resource_uid = "manual_check"
resource_name = "Manual check"
check_id = "manual"
muted = False
+1 -1
View File
@@ -15,7 +15,7 @@ from prowler.lib.check.models import (
)
from prowler.lib.logger import logger
from prowler.lib.outputs.common import Status, fill_common_finding_data
from prowler.lib.outputs.compliance.compliance import get_check_compliance
from prowler.lib.outputs.compliance.compliance_check import get_check_compliance
from prowler.lib.outputs.utils import unroll_tags
from prowler.lib.utils.utils import dict_to_lowercase, get_nested_attribute
from prowler.providers.common.provider import Provider
+37 -18
View File
@@ -25,8 +25,8 @@ from prowler.lib.utils.utils import open_file, parse_json_file, print_boxes
from prowler.providers.aws.config import (
AWS_REGION_US_EAST_1,
AWS_STS_GLOBAL_ENDPOINT_REGION,
BOTO3_USER_AGENT_EXTRA,
ROLE_SESSION_NAME,
get_default_session_config,
)
from prowler.providers.aws.exceptions.exceptions import (
AWSAccessKeyIDInvalidError,
@@ -227,14 +227,15 @@ class AwsProvider(Provider):
# TODO: Use AwsSetUpSession ?????
# Configure the initial AWS Session using the local credentials: profile or environment variables
session_config = self.set_session_config(retries_max_attempts)
aws_session = self.setup_session(
mfa=mfa,
profile=profile,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
session_config=session_config,
)
session_config = self.set_session_config(retries_max_attempts)
# Current session and the original session points to the same session object until we get a new one, if needed
self._session = AWSSession(
current_session=aws_session,
@@ -630,6 +631,7 @@ class AwsProvider(Provider):
aws_access_key_id: str = None,
aws_secret_access_key: str = None,
aws_session_token: Optional[str] = None,
session_config: Optional[Config] = None,
) -> Session:
"""
setup_session sets up an AWS session using the provided credentials.
@@ -640,6 +642,9 @@ class AwsProvider(Provider):
- aws_access_key_id: The AWS access key ID.
- aws_secret_access_key: The AWS secret access key.
- aws_session_token: The AWS session token, optional.
- session_config: Botocore Config applied as the session's default
client config so every client created from the session inherits
the Prowler user agent and retry settings.
Returns:
- Session: The AWS session.
@@ -650,6 +655,9 @@ class AwsProvider(Provider):
try:
logger.debug("Creating original session ...")
if session_config is None:
session_config = AwsProvider.set_session_config(None)
session_arguments = {}
if profile:
session_arguments["profile_name"] = profile
@@ -661,6 +669,7 @@ class AwsProvider(Provider):
if mfa:
session = Session(**session_arguments)
session._session.set_default_client_config(session_config)
sts_client = session.client("sts")
# TODO: pass values from the input
@@ -673,7 +682,7 @@ class AwsProvider(Provider):
session_credentials = sts_client.get_session_token(
**get_session_token_arguments
)
return Session(
mfa_session = Session(
aws_access_key_id=session_credentials["Credentials"]["AccessKeyId"],
aws_secret_access_key=session_credentials["Credentials"][
"SecretAccessKey"
@@ -682,8 +691,12 @@ class AwsProvider(Provider):
"SessionToken"
],
)
mfa_session._session.set_default_client_config(session_config)
return mfa_session
else:
return Session(**session_arguments)
session = Session(**session_arguments)
session._session.set_default_client_config(session_config)
return session
except Exception as error:
logger.critical(
f"AWSSetUpSessionError[{error.__traceback__.tb_lineno}]: {error}"
@@ -698,6 +711,7 @@ class AwsProvider(Provider):
identity: AWSIdentityInfo,
assumed_role_configuration: AWSAssumeRoleConfiguration,
session: AWSSession,
session_config: Optional[Config] = None,
) -> Session:
"""
Sets up an assumed session using the provided assumed role credentials.
@@ -742,6 +756,13 @@ class AwsProvider(Provider):
assumed_session = BotocoreSession()
assumed_session._credentials = assumed_refreshable_credentials
assumed_session.set_config_variable("region", identity.profile_region)
if session_config is None:
session_config = (
session.session_config
if session is not None
else AwsProvider.set_session_config(None)
)
assumed_session.set_default_client_config(session_config)
return Session(
profile_name=identity.profile,
botocore_session=assumed_session,
@@ -870,7 +891,7 @@ class AwsProvider(Provider):
for region in enabled_regions:
regional_client = self._session.current_session.client(
service, region_name=region, config=self._session.session_config
service, region_name=region
)
regional_client.region = region
regional_clients[region] = regional_client
@@ -1140,21 +1161,16 @@ class AwsProvider(Provider):
Returns:
- Config: The botocore Config object
"""
# Set the maximum retries for the standard retrier config
default_session_config = Config(
retries={"max_attempts": 3, "mode": "standard"},
user_agent_extra=BOTO3_USER_AGENT_EXTRA,
)
default_session_config = get_default_session_config()
if retries_max_attempts:
# Create the new config
config = Config(
retries={
"max_attempts": retries_max_attempts,
"mode": "standard",
},
default_session_config = default_session_config.merge(
Config(
retries={
"max_attempts": retries_max_attempts,
"mode": "standard",
},
)
)
# Merge the new configuration
default_session_config = default_session_config.merge(config)
return default_session_config
@@ -1425,6 +1441,9 @@ class AwsProvider(Provider):
region_name=aws_region,
profile_name=profile,
)
session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
caller_identity = AwsProvider.validate_credentials(session, aws_region)
# Do an extra validation if the AWS account ID is provided
+9
View File
@@ -1,6 +1,15 @@
import os
from botocore.config import Config
AWS_STS_GLOBAL_ENDPOINT_REGION = "us-east-1"
AWS_REGION_US_EAST_1 = "us-east-1"
BOTO3_USER_AGENT_EXTRA = os.getenv("PROWLER_AWS_BOTO3_USER_AGENT_EXTRA", "APN_1826889")
ROLE_SESSION_NAME = "ProwlerAssessmentSession"
def get_default_session_config() -> Config:
return Config(
user_agent_extra=BOTO3_USER_AGENT_EXTRA,
retries={"max_attempts": 3, "mode": "standard"},
)
@@ -56,9 +56,7 @@ def quick_inventory(provider: AwsProvider, args):
try:
# Scan IAM only once
if not iam_was_scanned:
global_resources.extend(
get_iam_resources(provider.session.current_session)
)
global_resources.extend(get_iam_resources(provider))
iam_was_scanned = True
# Get regional S3 buckets since none-tagged buckets are not supported by the resourcegroupstaggingapi
@@ -312,8 +310,8 @@ def create_output(resources: list, provider: AwsProvider, args):
if args.output_bucket:
output_bucket = args.output_bucket
bucket_session = provider.session.current_session
# Check if -D was input
elif args.output_bucket_no_assume:
# The outer condition guarantees -D was input when -B was not
else:
output_bucket = args.output_bucket_no_assume
bucket_session = provider.session.original_session
@@ -375,9 +373,9 @@ def get_regional_buckets(provider: AwsProvider, region: str) -> list:
return regional_buckets
def get_iam_resources(session) -> list:
def get_iam_resources(provider: AwsProvider) -> list:
iam_resources = []
iam_client = session.client("iam")
iam_client = provider.session.current_session.client("iam")
try:
get_roles_paginator = iam_client.get_paginator("list_roles")
for page in get_roles_paginator.paginate():
+11 -2
View File
@@ -111,6 +111,13 @@ class S3:
- None
"""
if session:
# Preserve the caller's existing default config (and the
# retries_max_attempts already baked into it) instead of clobbering
# it with a freshly built one.
if session._session.get_default_client_config() is None:
session._session.set_default_client_config(
AwsProvider.set_session_config(retries_max_attempts)
)
self._session = session.client(__class__.__name__.lower())
else:
aws_setup_session = AwsSetUpSession(
@@ -127,8 +134,7 @@ class S3:
regions=regions,
)
self._session = aws_setup_session._session.current_session.client(
__class__.__name__.lower(),
config=aws_setup_session._session.session_config,
__class__.__name__.lower()
)
self._bucket_name = bucket_name
@@ -313,6 +319,9 @@ class S3:
region_name=aws_region,
profile_name=profile,
)
session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
s3_client = session.client(__class__.__name__.lower())
if "s3://" in bucket_name:
bucket_name = bucket_name.removeprefix("s3://")
@@ -148,6 +148,13 @@ class SecurityHub:
regions=regions,
)
self._session = aws_setup_session._session.current_session
# Only install the Prowler default config when the caller-supplied
# session does not already carry one — overwriting would drop the
# provider's retries_max_attempts value.
if aws_session and self._session._session.get_default_client_config() is None:
self._session._session.set_default_client_config(
AwsProvider.set_session_config(retries_max_attempts)
)
self._aws_account_id = aws_account_id
if not aws_partition:
aws_partition = AwsProvider.validate_credentials(
@@ -235,7 +242,7 @@ class SecurityHub:
Args:
region (str): AWS region to check.
session (Session): AWS session object.
session (Session): AWS session object. Expected to carry the Prowler default client config.
aws_account_id (str): AWS account ID.
aws_partition (str): AWS partition.
@@ -540,6 +547,9 @@ class SecurityHub:
region_name=aws_region,
profile_name=profile,
)
session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
all_regions = AwsProvider.get_available_aws_service_regions(
service="securityhub", partition=aws_partition
+8 -2
View File
@@ -32,7 +32,13 @@ class AWSService:
def is_failed_check(cls, check_id, arn):
return (check_id.split(".")[-1], arn) in cls.failed_checks
def __init__(self, service: str, provider: AwsProvider, global_service=False):
def __init__(
self,
service: str,
provider: AwsProvider,
global_service=False,
region: str = None,
):
# Audit Information
# Do we need to store the whole provider?
self.provider = provider
@@ -61,7 +67,7 @@ class AWSService:
# Get a single region and client if the service needs it (e.g. AWS Global Service)
# We cannot include this within an else because some services needs both the regional_clients
# and a single client like S3
self.region = provider.get_default_region(
self.region = region or provider.get_default_region(
self.service, global_service=global_service
)
self.client = self.session.client(self.service, self.region)
@@ -73,15 +73,15 @@ class AwsSetUpSession:
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
# Setup the AWS session
session_config = AwsProvider.set_session_config(retries_max_attempts)
aws_session = AwsProvider.setup_session(
mfa=mfa,
profile=profile,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
session_config=session_config,
)
session_config = AwsProvider.set_session_config(retries_max_attempts)
self._session = AWSSession(
current_session=aws_session,
session_config=session_config,
@@ -1,4 +1,5 @@
import datetime
from concurrent.futures import as_completed
from typing import List, Optional
from pydantic.v1 import BaseModel
@@ -14,9 +15,9 @@ class Codebuild(AWSService):
super().__init__(__class__.__name__, provider)
self.projects = {}
self.__threading_call__(self._list_projects)
self.__threading_call__(self._list_builds_for_project, self.projects.values())
self.__threading_call__(self._batch_get_builds, self.projects.values())
self.__threading_call__(self._batch_get_projects, self.projects.values())
self.__threading_call__(self._list_builds_for_project)
self.__threading_call__(self._batch_get_builds)
self.__threading_call__(self._batch_get_projects)
self.report_groups = {}
self.__threading_call__(self._list_report_groups)
self.__threading_call__(
@@ -44,10 +45,8 @@ class Codebuild(AWSService):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _list_builds_for_project(self, project):
logger.info("Codebuild - Listing builds...")
def _fetch_project_last_build(self, regional_client, project):
try:
regional_client = self.regional_clients[project.region]
build_ids = regional_client.list_builds_for_project(
projectName=project.name
).get("ids", [])
@@ -58,28 +57,99 @@ class Codebuild(AWSService):
f"{project.region}: {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _batch_get_builds(self, project):
logger.info("Codebuild - Getting builds...")
def _list_builds_for_project(self, regional_client):
logger.info("Codebuild - Listing builds...")
try:
if project.last_build and project.last_build.id:
regional_client = self.regional_clients[project.region]
builds_by_id = regional_client.batch_get_builds(
ids=[project.last_build.id]
).get("builds", [])
if len(builds_by_id) > 0:
project.last_invoked_time = builds_by_id[0].get("endTime")
regional_projects = [
project
for project in self.projects.values()
if project.region == regional_client.region
]
# list_builds_for_project has no batch API equivalent, so reuse the
# shared thread pool to issue per-project calls in parallel within
# this region — preserving the wall-clock performance of the
# previous implementation.
futures = [
self.thread_pool.submit(
self._fetch_project_last_build, regional_client, project
)
for project in regional_projects
]
for future in as_completed(futures):
try:
future.result()
except Exception:
pass
except Exception as error:
logger.error(
f"{regional_client.region}: {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _batch_get_projects(self, project):
def _batch_get_builds(self, regional_client):
logger.info("Codebuild - Getting builds...")
try:
# Collect all build IDs for this region
build_id_to_project = {}
for project in self.projects.values():
if (
project.region == regional_client.region
and project.last_build
and project.last_build.id
):
build_id_to_project[project.last_build.id] = project
if not build_id_to_project:
return
build_ids = list(build_id_to_project.keys())
# batch_get_builds supports up to 100 IDs per call
for i in range(0, len(build_ids), 100):
batch = build_ids[i : i + 100]
response = regional_client.batch_get_builds(ids=batch)
for build_info in response.get("builds", []):
build_id = build_info.get("id")
if build_id in build_id_to_project:
end_time = build_info.get("endTime")
if end_time:
build_id_to_project[build_id].last_invoked_time = end_time
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _batch_get_projects(self, regional_client):
logger.info("Codebuild - Getting projects...")
try:
regional_client = self.regional_clients[project.region]
project_info = regional_client.batch_get_projects(names=[project.name])[
"projects"
][0]
# Collect all project names for this region
regional_projects = {
arn: project
for arn, project in self.projects.items()
if project.region == regional_client.region
}
if not regional_projects:
return
project_names = [project.name for project in regional_projects.values()]
# batch_get_projects supports up to 100 names per call
for i in range(0, len(project_names), 100):
batch = project_names[i : i + 100]
response = regional_client.batch_get_projects(names=batch)
for project_info in response.get("projects", []):
project_arn = project_info.get("arn")
if project_arn in regional_projects:
self._parse_project_info(
regional_projects[project_arn], project_info
)
except Exception as error:
logger.error(
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _parse_project_info(self, project, project_info):
try:
project.buildspec = project_info["source"].get("buildspec")
if project_info["source"]["type"] != "NO_SOURCE":
project.source = Source(
@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService
class GlobalAccelerator(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Global Accelerator is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
# That is, for example, specify --region us-west-2 on AWS CLI commands.
region = "us-west-2" if provider.identity.partition == "aws" else None
super().__init__(__class__.__name__, provider, region=region)
self.accelerators = {}
if self.audited_partition == "aws":
# Global Accelerator is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
# That is, for example, specify --region us-west-2 on AWS CLI commands.
self.region = "us-west-2"
self.client = self.session.client(self.service, self.region)
self._list_accelerators()
self.__threading_call__(self._list_tags, self.accelerators.values())
@@ -176,14 +176,12 @@ class RecordSet(BaseModel):
class Route53Domains(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# Route53Domains is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
region = "us-east-1" if provider.identity.partition == "aws" else None
super().__init__(__class__.__name__, provider, region=region)
self.domains = {}
if self.audited_partition == "aws":
# Route53Domains is a global service that supports endpoints in multiple AWS Regions
# but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
self.region = "us-east-1"
self.client = self.session.client(self.service, self.region)
self._list_domains()
self._get_domain_detail()
self._list_tags_for_domain()
@@ -9,20 +9,20 @@ from prowler.providers.aws.lib.service.service import AWSService
class TrustedAdvisor(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__("support", provider)
# Support API is not available in China Partition
# But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
partition = provider.identity.partition
if partition == "aws":
support_region = "us-east-1"
elif partition == "aws-cn":
support_region = None
else:
support_region = "us-gov-west-1"
super().__init__("support", provider, region=support_region)
self.account_arn_template = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:account"
self.checks = []
self.premium_support = PremiumSupport(enabled=False)
# Support API is not available in China Partition
# But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
if self.audited_partition != "aws-cn":
if self.audited_partition == "aws":
support_region = "us-east-1"
else:
support_region = "us-gov-west-1"
self.client = self.session.client(self.service, region_name=support_region)
self.client.region = support_region
self._describe_services()
if getattr(self.premium_support, "enabled", False):
self._describe_trusted_advisor_checks()
@@ -34,13 +34,13 @@ class TrustedAdvisor(AWSService):
for check in self.client.describe_trusted_advisor_checks(language="en").get(
"checks", []
):
check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.client.region}:{self.audited_account}:check/{check['id']}"
check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:check/{check['id']}"
self.checks.append(
Check(
id=check["id"],
name=check["name"],
arn=check_arn,
region=self.client.region,
region=self.region,
)
)
except ClientError as error:
@@ -50,22 +50,22 @@ class TrustedAdvisor(AWSService):
== "Amazon Web Services Premium Support Subscription is required to use this service."
):
logger.warning(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
else:
logger.error(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _describe_trusted_advisor_check_result(self):
logger.info("TrustedAdvisor - Describing Check Result...")
try:
for check in self.checks:
if check.region == self.client.region:
if check.region == self.region:
try:
response = self.client.describe_trusted_advisor_check_result(
checkId=check.id
@@ -78,11 +78,11 @@ class TrustedAdvisor(AWSService):
== "InvalidParameterValueException"
):
logger.warning(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
except Exception as error:
logger.error(
f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
def _describe_services(self):
@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService
class WAF(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__("waf", provider)
# AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
region = "us-east-1" if provider.identity.partition == "aws" else None
super().__init__("waf", provider, region=region)
self.rules = {}
self.rule_groups = {}
self.web_acls = {}
if self.audited_partition == "aws":
# AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
self.region = "us-east-1"
self.client = self.session.client(self.service, self.region)
self._list_rules()
self.__threading_call__(self._get_rule, self.rules.values())
self._list_rule_groups()
@@ -11,13 +11,11 @@ from prowler.providers.aws.lib.service.service import AWSService
class WAFv2(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
super().__init__(__class__.__name__, provider)
# AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
region = "us-east-1" if provider.identity.partition == "aws" else None
super().__init__(__class__.__name__, provider, region=region)
self.web_acls = {}
if self.audited_partition == "aws":
# AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
self.region = "us-east-1"
self.client = self.session.client(self.service, self.region)
self._list_web_acls_global()
self.__threading_call__(self._list_web_acls_regional)
self.__threading_call__(self._get_web_acl, self.web_acls.values())
+27
View File
@@ -436,6 +436,33 @@ class Test_Config:
assert "csa_ccm_4.0" in aws_frameworks
assert "csa_ccm_4.0" not in kubernetes_frameworks
def test_get_available_compliance_frameworks_no_provider_includes_universals(self):
"""Regression test for the variable shadowing bug.
Previously, the inner ``for provider in providers`` loop shadowed
the outer ``provider`` parameter. When called without a provider,
the post-loop ``if provider:`` branch wrongly applied
``framework.supports_provider(<last provider iterated>)`` and
excluded universal frameworks from the result.
Result: the parser-level ``available_compliance_frameworks``
constant was missing universal frameworks like ``csa_ccm_4.0``,
which made ``--compliance csa_ccm_4.0`` reject the choice.
"""
all_frameworks = get_available_compliance_frameworks()
assert "csa_ccm_4.0" in all_frameworks
def test_get_available_compliance_frameworks_does_not_mutate_provider_param(self):
"""Calling with a specific provider must not affect a subsequent
call without provider. Validates that the loop variable rename
prevents leaking state between calls."""
# Force an iteration over multiple providers first
get_available_compliance_frameworks("kubernetes")
# Then a no-provider call must still include universals supported
# by ANY provider (not filtered by some leaked value)
all_frameworks = get_available_compliance_frameworks()
assert "csa_ccm_4.0" in all_frameworks
def test_load_and_validate_config_file_aws(self):
path = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
config_test_file = f"{path}/fixtures/config.yaml"
+174
View File
@@ -675,3 +675,177 @@ class TestCheckLoader:
)
assert CLOUDTRAIL_THREAT_DETECTION_ENUMERATION_NAME not in result
assert S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME in result
def test_load_checks_to_execute_universal_framework_takes_precedence(self):
"""When ``--compliance <fw>`` matches a universal framework, the
loader must source checks from ``universal_frameworks[fw].requirements[*]
.checks[provider]`` and NOT fall through to ``bulk_compliance_frameworks``.
This is the path added by PR #10301 in checks_loader.py.
"""
from prowler.lib.check.compliance_models import (
ComplianceFramework,
UniversalComplianceRequirement,
)
bulk_checks_metadata = {
S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
}
universal_framework = ComplianceFramework(
framework="csa_ccm",
name="CSA CCM 4.0",
version="4.0",
description="Cloud Controls Matrix",
requirements=[
UniversalComplianceRequirement(
id="A&A-01",
description="Audit & Assurance",
attributes={},
checks={"aws": [S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME]},
),
],
)
with patch(
"prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
return_value=bulk_checks_metadata,
):
result = load_checks_to_execute(
bulk_checks_metadata=bulk_checks_metadata,
bulk_compliance_frameworks={}, # legacy empty
compliance_frameworks=["csa_ccm_4.0"],
provider=self.provider,
universal_frameworks={"csa_ccm_4.0": universal_framework},
)
assert result == {S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME}
def test_load_checks_to_execute_universal_filters_by_provider(self):
"""A universal requirement may declare checks for several
providers; the loader must only return those for the active
provider key (lowercased)."""
from prowler.lib.check.compliance_models import (
ComplianceFramework,
UniversalComplianceRequirement,
)
bulk_checks_metadata = {
S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
}
# The same requirement maps a different check per provider.
# Only the AWS one must be returned for provider="aws".
universal_framework = ComplianceFramework(
framework="csa_ccm",
name="CSA CCM 4.0",
version="4.0",
description="Cloud Controls Matrix",
requirements=[
UniversalComplianceRequirement(
id="A&A-02",
description="Multi-provider req",
attributes={},
checks={
"aws": [S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME],
"azure": ["azure_only_check"],
"gcp": ["gcp_only_check"],
},
),
],
)
with patch(
"prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
return_value=bulk_checks_metadata,
):
result = load_checks_to_execute(
bulk_checks_metadata=bulk_checks_metadata,
bulk_compliance_frameworks={},
compliance_frameworks=["csa_ccm_4.0"],
provider=self.provider, # "aws"
universal_frameworks={"csa_ccm_4.0": universal_framework},
)
assert S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME in result
assert "azure_only_check" not in result
assert "gcp_only_check" not in result
def test_load_checks_to_execute_universal_no_match_falls_back_to_legacy(self):
"""If the requested compliance framework is not present in
``universal_frameworks``, the loader must fall back to the
legacy ``bulk_compliance_frameworks`` lookup."""
bulk_checks_metadata = {
S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
}
bulk_compliance_frameworks = {
"soc2_aws": Compliance(
Framework="SOC2",
Name="SOC2",
Provider="aws",
Version="2.0",
Description="x",
Requirements=[
Compliance_Requirement(
Checks=[S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME],
Id="",
Description="",
Attributes=[],
)
],
),
}
with patch(
"prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
return_value=bulk_checks_metadata,
):
result = load_checks_to_execute(
bulk_checks_metadata=bulk_checks_metadata,
bulk_compliance_frameworks=bulk_compliance_frameworks,
compliance_frameworks=["soc2_aws"],
provider=self.provider,
universal_frameworks={"some_other_universal_fw": object()},
)
assert result == {S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME}
def test_load_checks_to_execute_universal_unknown_provider_returns_empty(self):
"""If the universal requirement has no checks for the active
provider, no checks are picked up for that requirement."""
from prowler.lib.check.compliance_models import (
ComplianceFramework,
UniversalComplianceRequirement,
)
bulk_checks_metadata = {
S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
}
universal_framework = ComplianceFramework(
framework="csa_ccm",
name="CSA CCM 4.0",
version="4.0",
description="Cloud Controls Matrix",
requirements=[
UniversalComplianceRequirement(
id="A&A-03",
description="Only Azure",
attributes={},
checks={"azure": ["azure_only_check"]},
),
],
)
with patch(
"prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
return_value=bulk_checks_metadata,
):
result = load_checks_to_execute(
bulk_checks_metadata=bulk_checks_metadata,
bulk_compliance_frameworks={},
compliance_frameworks=["csa_ccm_4.0"],
provider=self.provider, # "aws" — no checks declared
universal_frameworks={"csa_ccm_4.0": universal_framework},
)
assert result == set()
@@ -442,3 +442,123 @@ class TestComplianceOutput:
)
assert compliance_output.file_extension == ".csv"
class TestComplianceCheckHelperModule:
"""Tests for the new ``compliance_check`` leaf module that hosts
``get_check_compliance``.
This module exists to break the cyclic import chain
``finding -> compliance.compliance -> universal.* -> finding`` that
CodeQL flagged. It must be:
- importable directly without pulling in the universal pipeline
- re-exported by ``compliance.compliance`` for backward compatibility
- the SAME function object, regardless of import path
"""
def test_module_is_importable_directly(self):
"""The helper module must be importable on its own — it is the
leaf used by ``finding.py`` to break the cyclic import chain."""
from prowler.lib.outputs.compliance import compliance_check
assert hasattr(compliance_check, "get_check_compliance")
assert callable(compliance_check.get_check_compliance)
def test_helper_module_only_depends_on_check_models_and_logger(self):
"""The helper must not pull in universal pipeline modules; that
was the whole point of extracting it. Inspecting the module's
own imports keeps it honest without polluting ``sys.modules``."""
import inspect
from prowler.lib.outputs.compliance import compliance_check
source = inspect.getsource(compliance_check)
# Only these two prowler imports are allowed in the leaf module
assert "from prowler.lib.check.models import Check_Report" in source
assert "from prowler.lib.logger import logger" in source
# And NOT these (would re-introduce the cycle):
assert "from prowler.lib.outputs.compliance.universal" not in source
assert "from prowler.lib.outputs.finding" not in source
assert "from prowler.lib.outputs.ocsf" not in source
def test_re_export_from_compliance_compliance(self):
"""``compliance.compliance.get_check_compliance`` must point to
the same function as ``compliance.compliance_check.get_check_compliance``."""
from prowler.lib.outputs.compliance.compliance import (
get_check_compliance as via_compliance,
)
from prowler.lib.outputs.compliance.compliance_check import (
get_check_compliance as via_helper,
)
assert via_compliance is via_helper
def test_re_export_from_finding_module(self):
"""``finding.get_check_compliance`` must point to the same
function. Test mocks rely on this attribute existing on the
``prowler.lib.outputs.finding`` module."""
from prowler.lib.outputs.compliance.compliance_check import (
get_check_compliance as via_helper,
)
from prowler.lib.outputs.finding import get_check_compliance as via_finding
assert via_finding is via_helper
def test_returns_empty_dict_on_unknown_check(self):
"""Sanity test of the function logic via the helper module."""
from prowler.lib.outputs.compliance.compliance_check import (
get_check_compliance,
)
finding = mock.MagicMock()
finding.check_metadata.CheckID = "unknown_check_id"
result = get_check_compliance(finding, "aws", {})
assert result == {}
def test_filters_by_provider(self):
"""The function returns frameworks only for the matching provider."""
from prowler.lib.outputs.compliance.compliance_check import (
get_check_compliance,
)
compliance_aws = mock.MagicMock(
Framework="CIS",
Version="1.4",
Provider="AWS",
Requirements=[mock.MagicMock(Id="2.1.3")],
)
compliance_azure = mock.MagicMock(
Framework="CIS",
Version="2.0",
Provider="Azure",
Requirements=[mock.MagicMock(Id="9.1")],
)
finding = mock.MagicMock()
finding.check_metadata.CheckID = "shared_check"
bulk = {
"shared_check": mock.MagicMock(
Compliance=[compliance_aws, compliance_azure]
)
}
# Only AWS frameworks come back
result = get_check_compliance(finding, "aws", bulk)
assert "CIS-1.4" in result
assert "CIS-2.0" not in result
def test_returns_empty_dict_on_exception(self):
"""If iteration raises, the function logs the error and returns
an empty dict (defensive behaviour)."""
from prowler.lib.outputs.compliance.compliance_check import (
get_check_compliance,
)
# bulk_checks_metadata that raises when accessed → defensive path
class Boom:
def __contains__(self, _key):
raise RuntimeError("boom")
finding = mock.MagicMock()
finding.check_metadata.CheckID = "any"
result = get_check_compliance(finding, "aws", Boom())
assert result == {}
@@ -0,0 +1,244 @@
"""Tests for display_compliance_table dispatch logic.
Validates that each compliance framework name is routed to the correct
table renderer via startswith matching, and that the universal early-return
takes precedence when applicable.
"""
from unittest.mock import patch
import pytest
from prowler.lib.check.compliance_models import (
ComplianceFramework,
OutputsConfig,
TableConfig,
UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.compliance import display_compliance_table
MODULE = "prowler.lib.outputs.compliance.compliance"
# Common args shared by every call — the actual values don't matter
# because we mock the downstream renderers.
_COMMON = dict(
findings=[],
bulk_checks_metadata={},
output_filename="out",
output_directory="/tmp",
compliance_overview=False,
)
# ── Dispatch to legacy table renderers ───────────────────────────────
class TestDispatchStartswith:
"""Each framework prefix must route to exactly one renderer."""
@pytest.mark.parametrize(
"framework_name",
[
"cis_1.4_aws",
"cis_2.0_azure",
"cis_3.0_gcp",
"cis_6.0_m365",
"cis_1.10_kubernetes",
],
)
@patch(f"{MODULE}.get_cis_table")
def test_cis_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
@pytest.mark.parametrize(
"framework_name",
["ens_rd2022_aws", "ens_rd2022_azure", "ens_rd2022_gcp"],
)
@patch(f"{MODULE}.get_ens_table")
def test_ens_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
@pytest.mark.parametrize(
"framework_name",
["mitre_attack_aws", "mitre_attack_azure", "mitre_attack_gcp"],
)
@patch(f"{MODULE}.get_mitre_attack_table")
def test_mitre_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
@pytest.mark.parametrize(
"framework_name",
["kisa_isms_p_2023_aws", "kisa_isms_p_2023_korean_aws"],
)
@patch(f"{MODULE}.get_kisa_ismsp_table")
def test_kisa_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
@pytest.mark.parametrize(
"framework_name",
[
"prowler_threatscore_aws",
"prowler_threatscore_azure",
"prowler_threatscore_gcp",
"prowler_threatscore_kubernetes",
"prowler_threatscore_m365",
"prowler_threatscore_alibabacloud",
],
)
@patch(f"{MODULE}.get_prowler_threatscore_table")
def test_threatscore_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
@pytest.mark.parametrize(
"framework_name",
[
"csa_ccm_4.0_aws",
"csa_ccm_4.0_azure",
"csa_ccm_4.0_gcp",
"csa_ccm_4.0_oraclecloud",
"csa_ccm_4.0_alibabacloud",
],
)
@patch(f"{MODULE}.get_csa_table")
def test_csa_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
@pytest.mark.parametrize(
"framework_name",
["c5_aws", "c5_azure", "c5_gcp"],
)
@patch(f"{MODULE}.get_c5_table")
def test_c5_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
@pytest.mark.parametrize(
"framework_name",
[
"soc2_aws",
"hipaa_aws",
"gdpr_aws",
"nist_800_53_revision_4_aws",
"pci_3.2.1_aws",
"iso27001_2013_aws",
"aws_well_architected_framework_security_pillar_aws",
"fedramp_low_revision_4_aws",
"cisa_aws",
],
)
@patch(f"{MODULE}.get_generic_compliance_table")
def test_generic_dispatch(self, mock_fn, framework_name):
display_compliance_table(compliance_framework=framework_name, **_COMMON)
mock_fn.assert_called_once()
# ── No false matches (the old `in` bug) ─────────────────────────────
class TestNoFalseSubstringMatches:
"""Frameworks that previously could false-match with `in` must NOT
be routed to the wrong renderer now that we use startswith."""
@patch(f"{MODULE}.get_ens_table")
@patch(f"{MODULE}.get_generic_compliance_table")
def test_cisa_does_not_match_cis(self, mock_generic, mock_cis):
"""'cisa_aws' must NOT match startswith('cis_')."""
display_compliance_table(compliance_framework="cisa_aws", **_COMMON)
mock_generic.assert_called_once()
mock_cis.assert_not_called()
@patch(f"{MODULE}.get_prowler_threatscore_table")
@patch(f"{MODULE}.get_generic_compliance_table")
def test_threatscore_prefix_not_partial(self, mock_generic, mock_ts):
"""A hypothetical 'threatscore_custom_aws' must NOT match
startswith('prowler_threatscore_')."""
display_compliance_table(
compliance_framework="threatscore_custom_aws", **_COMMON
)
mock_generic.assert_called_once()
mock_ts.assert_not_called()
@patch(f"{MODULE}.get_ens_table")
@patch(f"{MODULE}.get_prowler_threatscore_table")
def test_prowler_threatscore_does_not_match_ens(self, mock_ts, mock_ens):
"""'prowler_threatscore_aws' must hit threatscore, never ens."""
display_compliance_table(
compliance_framework="prowler_threatscore_aws", **_COMMON
)
mock_ts.assert_called_once()
mock_ens.assert_not_called()
# ── Universal early-return ───────────────────────────────────────────
class TestUniversalEarlyReturn:
"""The universal path must take precedence over the elif chain."""
@staticmethod
def _make_fw():
return ComplianceFramework(
framework="CIS",
name="CIS",
provider="AWS",
version="5.0",
description="d",
requirements=[
UniversalComplianceRequirement(
id="1.1",
description="d",
attributes={},
checks={"aws": ["check_a"]},
),
],
outputs=OutputsConfig(table_config=TableConfig(group_by="_default")),
)
@patch(f"{MODULE}.get_universal_table")
@patch(f"{MODULE}.get_cis_table")
def test_universal_takes_precedence_over_cis(self, mock_cis, mock_universal):
"""A CIS framework in universal_frameworks with TableConfig must
use the universal renderer, not get_cis_table."""
fw = self._make_fw()
display_compliance_table(
compliance_framework="cis_5.0_aws",
universal_frameworks={"cis_5.0_aws": fw},
**_COMMON,
)
mock_universal.assert_called_once()
mock_cis.assert_not_called()
@patch(f"{MODULE}.get_universal_table")
@patch(f"{MODULE}.get_cis_table")
def test_falls_through_without_table_config(self, mock_cis, mock_universal):
"""If the universal framework has no TableConfig, fall through
to the legacy elif chain."""
fw = self._make_fw()
fw.outputs = None
display_compliance_table(
compliance_framework="cis_5.0_aws",
universal_frameworks={"cis_5.0_aws": fw},
**_COMMON,
)
mock_cis.assert_called_once()
mock_universal.assert_not_called()
@patch(f"{MODULE}.get_universal_table")
@patch(f"{MODULE}.get_generic_compliance_table")
def test_falls_through_when_not_in_universal_dict(
self, mock_generic, mock_universal
):
"""If universal_frameworks is empty, fall through to legacy."""
display_compliance_table(
compliance_framework="soc2_aws",
universal_frameworks={},
**_COMMON,
)
mock_generic.assert_called_once()
mock_universal.assert_not_called()
@@ -0,0 +1,128 @@
from io import StringIO
from unittest import mock
from freezegun import freeze_time
from mock import patch
from prowler.lib.outputs.compliance.essential_eight.essential_eight_aws import (
EssentialEightAWS,
)
from prowler.lib.outputs.compliance.essential_eight.models import (
EssentialEightAWSModel,
)
from tests.lib.outputs.compliance.fixtures import ESSENTIAL_EIGHT_AWS
from tests.lib.outputs.fixtures.fixtures import generate_finding_output
from tests.providers.aws.utils import AWS_ACCOUNT_NUMBER, AWS_REGION_EU_WEST_1
# The fixture's first Requirement maps clause "E8-1.8" (Patch applications,
# clause 8: removal of unsupported online services). The second Requirement is
# E8-6.1 (Restrict Office macros, clause 1) which has no Checks and is therefore
# emitted as a manual row.
COMPLIANCE_NAME = "Essential-Eight-Nov 2023"
class TestEssentialEightAWS:
def test_output_transform(self):
findings = [generate_finding_output(compliance={COMPLIANCE_NAME: "E8-1.8"})]
output = EssentialEightAWS(findings, ESSENTIAL_EIGHT_AWS)
output_data = output.data[0]
assert isinstance(output_data, EssentialEightAWSModel)
assert output_data.Provider == "aws"
assert output_data.Framework == ESSENTIAL_EIGHT_AWS.Framework
assert output_data.Name == ESSENTIAL_EIGHT_AWS.Name
assert output_data.Description == ESSENTIAL_EIGHT_AWS.Description
assert output_data.AccountId == AWS_ACCOUNT_NUMBER
assert output_data.Region == AWS_REGION_EU_WEST_1
assert output_data.Requirements_Id == "E8-1.8"
assert (
output_data.Requirements_Description
== ESSENTIAL_EIGHT_AWS.Requirements[0].Description
)
assert output_data.Requirements_Attributes_Section == "1 Patch applications"
assert output_data.Requirements_Attributes_MaturityLevel == "ML1"
assert output_data.Requirements_Attributes_AssessmentStatus == "Automated"
assert output_data.Requirements_Attributes_CloudApplicability == "full"
assert (
output_data.Requirements_Attributes_MitigatedThreats
== "Use of unsupported software, Long-tail vulnerability accumulation"
)
assert (
output_data.Requirements_Attributes_Description
== ESSENTIAL_EIGHT_AWS.Requirements[0].Attributes[0].Description
)
assert output_data.Status == "PASS"
assert output_data.StatusExtended == ""
assert output_data.ResourceId == ""
assert output_data.ResourceName == ""
assert output_data.CheckId == "service_test_check_id"
assert not output_data.Muted
def test_manual_requirement(self):
findings = [generate_finding_output(compliance={COMPLIANCE_NAME: "E8-1.8"})]
output = EssentialEightAWS(findings, ESSENTIAL_EIGHT_AWS)
# E8-6.1 (macros) has no Checks -> emitted as a manual row, non-applicable
manual_rows = [row for row in output.data if row.Status == "MANUAL"]
assert len(manual_rows) == 1
manual = manual_rows[0]
assert manual.Provider == "aws"
assert manual.AccountId == ""
assert manual.Region == ""
assert manual.Requirements_Id == "E8-6.1"
assert (
manual.Requirements_Attributes_Section
== "6 Restrict Microsoft Office macros"
)
assert manual.Requirements_Attributes_MaturityLevel == "ML1"
assert manual.Requirements_Attributes_AssessmentStatus == "Manual"
assert manual.Requirements_Attributes_CloudApplicability == "non-applicable"
assert (
manual.Requirements_Attributes_MitigatedThreats
== "Macro-based malware delivery"
)
assert manual.StatusExtended == "Manual check"
assert manual.ResourceId == "manual_check"
assert manual.ResourceName == "Manual check"
assert manual.CheckId == "manual"
assert not manual.Muted
@freeze_time("2025-01-01 00:00:00")
@mock.patch(
"prowler.lib.outputs.compliance.essential_eight.essential_eight_aws.timestamp",
"2025-01-01 00:00:00",
)
def test_batch_write_data_to_file(self):
mock_file = StringIO()
findings = [generate_finding_output(compliance={COMPLIANCE_NAME: "E8-1.8"})]
output = EssentialEightAWS(findings, ESSENTIAL_EIGHT_AWS)
output._file_descriptor = mock_file
with patch.object(mock_file, "close", return_value=None):
output.batch_write_data_to_file()
mock_file.seek(0)
content = mock_file.read()
# Validate header carries the E8-specific column names
first_line = content.split("\r\n", 1)[0]
for column in (
"REQUIREMENTS_ATTRIBUTES_MATURITYLEVEL",
"REQUIREMENTS_ATTRIBUTES_ASSESSMENTSTATUS",
"REQUIREMENTS_ATTRIBUTES_CLOUDAPPLICABILITY",
"REQUIREMENTS_ATTRIBUTES_MITIGATEDTHREATS",
"REQUIREMENTS_ATTRIBUTES_RATIONALESTATEMENT",
"REQUIREMENTS_ATTRIBUTES_REMEDIATIONPROCEDURE",
"REQUIREMENTS_ATTRIBUTES_AUDITPROCEDURE",
):
assert column in first_line, f"missing column {column} in CSV header"
# rows: header + matched + manual
rows = [r for r in content.split("\r\n") if r]
assert len(rows) == 3
assert rows[1].split(";")[0] == "aws"
assert "ML1" in rows[1]
assert ";PASS;" in rows[1]
assert ";MANUAL;" in rows[2]
assert ";manual_check;" in rows[2]
+56
View File
@@ -7,6 +7,7 @@ from prowler.lib.check.compliance_models import (
ENS_Requirement_Attribute,
ENS_Requirement_Attribute_Nivel,
ENS_Requirement_Attribute_Tipos,
EssentialEight_Requirement_Attribute,
Generic_Compliance_Requirement_Attribute,
ISO27001_2013_Requirement_Attribute,
KISA_ISMSP_Requirement_Attribute,
@@ -1189,3 +1190,58 @@ CCC_GCP_FIXTURE = Compliance(
),
],
)
ESSENTIAL_EIGHT_AWS = Compliance(
Framework="Essential-Eight",
Name="ASD Essential Eight Maturity Model - Maturity Level One (AWS)",
Version="Nov 2023",
Provider="AWS",
Description="Literal mapping of the Australian Signals Directorate (ASD) Essential Eight Maturity Model ML1 to AWS infrastructure checks.",
Requirements=[
Compliance_Requirement(
Id="E8-1.8",
Description="Online services that are no longer supported by vendors are removed.",
Attributes=[
EssentialEight_Requirement_Attribute(
Section="1 Patch applications",
MaturityLevel="ML1",
AssessmentStatus="Automated",
CloudApplicability="full",
MitigatedThreats=[
"Use of unsupported software",
"Long-tail vulnerability accumulation",
],
Description="Detect and remove unsupported AWS-hosted online services (Lambda runtimes, RDS engines, EKS, Fargate, Kafka, OpenSearch).",
RationaleStatement="Unsupported services no longer receive security patches.",
ImpactStatement="",
RemediationProcedure="Migrate Lambda off deprecated runtimes; remove RDS Extended Support; upgrade EKS.",
AuditProcedure="Run all listed checks.",
AdditionalInformation="ASD Essential Eight ML1 - Patch applications - clause 8.",
References="https://www.cyber.gov.au/resources-business-and-government/essential-cyber-security/essential-eight/essential-eight-maturity-model",
)
],
Checks=["service_test_check_id"],
),
Compliance_Requirement(
Id="E8-6.1",
Description="Microsoft Office macros are disabled for users that do not have a demonstrated business requirement.",
Attributes=[
EssentialEight_Requirement_Attribute(
Section="6 Restrict Microsoft Office macros",
MaturityLevel="ML1",
AssessmentStatus="Manual",
CloudApplicability="non-applicable",
MitigatedThreats=["Macro-based malware delivery"],
Description="Endpoint / Microsoft 365 control. Out of AWS infrastructure scope.",
RationaleStatement="Most users never need Office macros.",
ImpactStatement="",
RemediationProcedure="Disable macros via Group Policy / Intune / M365 admin policies.",
AuditProcedure="Manual review of M365 macro policy.",
AdditionalInformation="ASD Essential Eight ML1 - Restrict Microsoft Office macros - clause 1. Out of AWS infrastructure scope.",
References="https://www.cyber.gov.au/resources-business-and-government/essential-cyber-security/essential-eight/essential-eight-maturity-model",
)
],
Checks=[],
),
],
)
@@ -0,0 +1,730 @@
"""Tests for process_universal_compliance_frameworks and --list-compliance fixes.
Validates that the pre-processing step:
- generates both CSV and OCSF outputs for universal frameworks
- always generates OCSF (no output-format gate)
- skips frameworks without outputs or table_config
- skips frameworks not in universal_frameworks
- returns the set of processed names for removal from the legacy loop
- works across different providers
Also validates that print_compliance_frameworks and print_compliance_requirements
work with universal ComplianceFramework objects (dict checks, None provider).
"""
import json
import os
from datetime import datetime, timezone
from types import SimpleNamespace
import pytest
from prowler.lib.check.check import (
print_compliance_frameworks,
print_compliance_requirements,
)
from prowler.lib.check.compliance_models import (
AttributeMetadata,
ComplianceFramework,
OutputsConfig,
TableConfig,
UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.compliance import (
process_universal_compliance_frameworks,
)
from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
OCSFComplianceOutput,
)
from prowler.lib.outputs.compliance.universal.universal_output import (
UniversalComplianceOutput,
)
@pytest.fixture(autouse=True)
def _create_compliance_dir(tmp_path):
"""Ensure the compliance/ subdirectory exists before each test."""
os.makedirs(tmp_path / "compliance", exist_ok=True)
# ── Helpers ──────────────────────────────────────────────────────────
def _make_finding(check_id, status="PASS", provider="aws"):
"""Create a mock Finding with all fields needed by both output classes."""
finding = SimpleNamespace()
finding.provider = provider
finding.account_uid = "123456789012"
finding.account_name = "test-account"
finding.account_email = ""
finding.account_organization_uid = "org-123"
finding.account_organization_name = "test-org"
finding.account_tags = {"env": "test"}
finding.region = "us-east-1"
finding.status = status
finding.status_extended = f"{check_id} is {status}"
finding.resource_uid = f"arn:aws:iam::123456789012:{check_id}"
finding.resource_name = check_id
finding.resource_details = "some details"
finding.resource_metadata = {}
finding.resource_tags = {"Name": "test"}
finding.partition = "aws"
finding.muted = False
finding.check_id = check_id
finding.uid = "test-finding-uid"
finding.timestamp = datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
finding.prowler_version = "5.0.0"
finding.compliance = {"TestFW-1.0": ["1.1"]}
finding.metadata = SimpleNamespace(
Provider=provider,
CheckID=check_id,
CheckTitle=f"Title for {check_id}",
CheckType=["test-type"],
Description=f"Description for {check_id}",
Severity="medium",
ServiceName="iam",
ResourceType="aws-iam-role",
Risk="test-risk",
RelatedUrl="https://example.com",
Remediation=SimpleNamespace(
Recommendation=SimpleNamespace(Text="Fix it", Url="https://fix.com"),
),
DependsOn=[],
RelatedTo=[],
Categories=["test"],
Notes="",
AdditionalURLs=[],
)
return finding
def _make_universal_framework(name="TestFW", version="1.0", with_table_config=True):
"""Build a ComplianceFramework with optional table_config."""
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="Test requirement",
attributes={"Section": "IAM"},
checks={"aws": ["check_a"]},
),
]
metadata = [AttributeMetadata(key="Section", type="str")]
outputs = None
if with_table_config:
outputs = OutputsConfig(table_config=TableConfig(group_by="Section"))
return ComplianceFramework(
framework=name,
name=f"{name} Framework",
provider="AWS",
version=version,
description="Test framework",
requirements=reqs,
attributes_metadata=metadata,
outputs=outputs,
)
# ── Tests ────────────────────────────────────────────────────────────
class TestProcessUniversalComplianceFrameworks:
"""Core tests for the extracted pre-processing function."""
def test_generates_csv_and_ocsf_outputs(self, tmp_path):
"""Both CSV and OCSF outputs are appended to generated_outputs."""
fw = _make_universal_framework()
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_output",
provider="aws",
generated_outputs=generated,
)
assert processed == {"test_fw_1.0"}
assert len(generated["compliance"]) == 2
assert isinstance(generated["compliance"][0], UniversalComplianceOutput)
assert isinstance(generated["compliance"][1], OCSFComplianceOutput)
def test_ocsf_always_generated_no_format_gate(self, tmp_path):
"""OCSF output is generated regardless of output_formats — no gate."""
fw = _make_universal_framework()
generated = {"compliance": []}
process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_output",
provider="aws",
generated_outputs=generated,
)
ocsf_outputs = [
o for o in generated["compliance"] if isinstance(o, OCSFComplianceOutput)
]
assert len(ocsf_outputs) == 1
def test_csv_file_written(self, tmp_path):
"""CSV file is created with expected content."""
fw = _make_universal_framework()
generated = {"compliance": []}
process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_output",
provider="aws",
generated_outputs=generated,
)
csv_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.csv"
assert csv_path.exists()
content = csv_path.read_text()
assert "PROVIDER" in content
assert "REQUIREMENTS_ATTRIBUTES_SECTION" in content
def test_ocsf_file_written(self, tmp_path):
"""OCSF JSON file is created with valid content."""
fw = _make_universal_framework()
generated = {"compliance": []}
process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_output",
provider="aws",
generated_outputs=generated,
)
ocsf_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.ocsf.json"
assert ocsf_path.exists()
data = json.loads(ocsf_path.read_text())
assert isinstance(data, list)
assert len(data) >= 1
assert data[0]["class_uid"] == 2003
def test_returns_processed_names(self, tmp_path):
"""Returns the set of framework names that were processed."""
fw = _make_universal_framework()
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0", "legacy_fw"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == {"test_fw_1.0"}
assert "legacy_fw" not in processed
class TestSkipConditions:
"""Tests for frameworks that should NOT be processed."""
def test_skips_framework_not_in_universal(self, tmp_path):
"""Frameworks not in universal_frameworks dict are skipped."""
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"cis_aws_1.4"},
universal_frameworks={},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == set()
assert len(generated["compliance"]) == 0
def test_skips_framework_without_outputs(self, tmp_path):
"""Frameworks with outputs=None are skipped."""
fw = _make_universal_framework(with_table_config=False)
# outputs is None since with_table_config=False
assert fw.outputs is None
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == set()
assert len(generated["compliance"]) == 0
def test_skips_framework_with_outputs_but_no_table_config(self, tmp_path):
"""Frameworks with outputs but table_config=None are skipped."""
fw = _make_universal_framework()
# Manually set table_config to None while keeping outputs
fw.outputs = OutputsConfig(table_config=None)
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == set()
assert len(generated["compliance"]) == 0
def test_empty_input_frameworks(self, tmp_path):
"""No processing when input set is empty."""
fw = _make_universal_framework()
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks=set(),
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == set()
assert len(generated["compliance"]) == 0
class TestMixedFrameworks:
"""Tests with a mix of universal and legacy frameworks."""
def test_only_universal_processed_legacy_untouched(self, tmp_path):
"""Only universal frameworks are processed; legacy names are not returned."""
universal_fw = _make_universal_framework()
generated = {"compliance": []}
all_frameworks = {"test_fw_1.0", "cis_aws_1.4", "nist_800_53_aws"}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks=all_frameworks,
universal_frameworks={"test_fw_1.0": universal_fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == {"test_fw_1.0"}
# 2 outputs for the one universal framework (CSV + OCSF)
assert len(generated["compliance"]) == 2
def test_removal_from_input_set(self, tmp_path):
"""Caller can subtract processed set from input to get legacy-only frameworks."""
universal_fw = _make_universal_framework()
generated = {"compliance": []}
input_frameworks = {"test_fw_1.0", "cis_aws_1.4", "nist_800_53_aws"}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks=input_frameworks,
universal_frameworks={"test_fw_1.0": universal_fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
remaining = input_frameworks - processed
assert remaining == {"cis_aws_1.4", "nist_800_53_aws"}
def test_multiple_universal_frameworks(self, tmp_path):
"""Multiple universal frameworks each get CSV + OCSF."""
fw1 = _make_universal_framework(name="FW1", version="1.0")
fw2 = _make_universal_framework(name="FW2", version="2.0")
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"fw1_1.0", "fw2_2.0", "legacy"},
universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == {"fw1_1.0", "fw2_2.0"}
# 2 frameworks × 2 outputs each = 4
assert len(generated["compliance"]) == 4
csv_outputs = [
o
for o in generated["compliance"]
if isinstance(o, UniversalComplianceOutput)
]
ocsf_outputs = [
o for o in generated["compliance"] if isinstance(o, OCSFComplianceOutput)
]
assert len(csv_outputs) == 2
assert len(ocsf_outputs) == 2
class TestProviderVariants:
"""Verify the function works for different providers."""
@pytest.mark.parametrize(
"provider",
[
"aws",
"azure",
"gcp",
"kubernetes",
"m365",
"github",
"oraclecloud",
"alibabacloud",
"nhn",
],
)
def test_all_providers_produce_outputs(self, tmp_path, provider):
"""Each provider generates CSV + OCSF when given a universal framework."""
fw = _make_universal_framework()
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a", provider=provider)],
output_directory=str(tmp_path),
output_filename="out",
provider=provider,
generated_outputs=generated,
)
assert processed == {"test_fw_1.0"}
assert len(generated["compliance"]) == 2
assert isinstance(generated["compliance"][0], UniversalComplianceOutput)
assert isinstance(generated["compliance"][1], OCSFComplianceOutput)
class TestEmptyFindings:
"""Test behavior when there are no findings."""
def test_still_processed_with_empty_findings(self, tmp_path):
"""Framework is still marked as processed even with no findings."""
fw = _make_universal_framework()
generated = {"compliance": []}
processed = process_universal_compliance_frameworks(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
assert processed == {"test_fw_1.0"}
# Outputs are still appended (they'll just have empty data)
assert len(generated["compliance"]) == 2
class TestFilePaths:
"""Verify correct file path construction."""
def test_csv_path_format(self, tmp_path):
"""CSV output has the correct file path."""
fw = _make_universal_framework()
generated = {"compliance": []}
process_universal_compliance_frameworks(
input_compliance_frameworks={"csa_ccm_4.0"},
universal_frameworks={"csa_ccm_4.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_report",
provider="aws",
generated_outputs=generated,
)
csv_output = generated["compliance"][0]
assert csv_output.file_path == (
f"{tmp_path}/compliance/prowler_report_csa_ccm_4.0.csv"
)
def test_ocsf_path_format(self, tmp_path):
"""OCSF output has the correct file path."""
fw = _make_universal_framework()
generated = {"compliance": []}
process_universal_compliance_frameworks(
input_compliance_frameworks={"csa_ccm_4.0"},
universal_frameworks={"csa_ccm_4.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_report",
provider="aws",
generated_outputs=generated,
)
ocsf_output = generated["compliance"][1]
assert ocsf_output.file_path == (
f"{tmp_path}/compliance/prowler_report_csa_ccm_4.0.ocsf.json"
)
# ── Tests for --list-compliance fix ──────────────────────────────────
def _make_legacy_compliance():
"""Create a mock legacy Compliance-like object with the expected attributes."""
return SimpleNamespace(
Framework="CIS",
Provider="AWS",
Version="1.4",
Requirements=[
SimpleNamespace(
Id="2.1.3",
Description="Ensure MFA Delete is enabled",
Checks=["s3_bucket_mfa_delete"],
),
],
)
class TestPrintComplianceFrameworks:
"""Tests for print_compliance_frameworks with universal frameworks."""
def test_includes_universal_frameworks(self, capsys):
"""Universal frameworks appear in the listing."""
legacy = {"cis_1.4_aws": _make_legacy_compliance()}
universal = {"csa_ccm_4.0": _make_universal_framework()}
merged = {**legacy, **universal}
print_compliance_frameworks(merged)
captured = capsys.readouterr().out
assert "cis_1.4_aws" in captured
assert "csa_ccm_4.0" in captured
def test_count_includes_both(self, capsys):
"""Framework count includes both legacy and universal."""
legacy = {"cis_1.4_aws": _make_legacy_compliance()}
universal = {"csa_ccm_4.0": _make_universal_framework()}
merged = {**legacy, **universal}
print_compliance_frameworks(merged)
captured = capsys.readouterr().out
assert "2" in captured
def test_universal_only(self, capsys):
"""Works when only universal frameworks are present."""
universal = {"csa_ccm_4.0": _make_universal_framework()}
print_compliance_frameworks(universal)
captured = capsys.readouterr().out
assert "csa_ccm_4.0" in captured
assert "1" in captured
class TestPrintComplianceRequirements:
"""Tests for print_compliance_requirements with universal frameworks."""
def test_list_checks_universal_framework(self, capsys):
"""Requirements with dict checks are printed correctly."""
fw = _make_universal_framework()
all_fw = {"test_fw_1.0": fw}
print_compliance_requirements(all_fw, ["test_fw_1.0"])
captured = capsys.readouterr().out
assert "1.1" in captured
assert "check_a" in captured
def test_dict_checks_universal_framework(self, capsys):
"""Requirements with dict checks show provider-prefixed checks."""
reqs = [
UniversalComplianceRequirement(
id="A&A-01",
description="Audit & Assurance",
attributes={"Section": "A&A"},
checks={"aws": ["check_a", "check_b"], "azure": ["check_c"]},
),
]
fw = ComplianceFramework(
framework="CSA_CCM",
name="CSA CCM 4.0",
version="4.0",
description="Cloud Controls Matrix",
requirements=reqs,
)
all_fw = {"csa_ccm_4.0": fw}
print_compliance_requirements(all_fw, ["csa_ccm_4.0"])
captured = capsys.readouterr().out
assert "A&A-01" in captured
assert "[aws] check_a" in captured
assert "[aws] check_b" in captured
assert "[azure] check_c" in captured
def test_none_provider_shows_multi_provider(self, capsys):
"""Frameworks with provider=None show 'Multi-provider'."""
fw = ComplianceFramework(
framework="CSA_CCM",
name="CSA CCM 4.0",
version="4.0",
description="Cloud Controls Matrix",
requirements=[
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={},
checks={"aws": ["check_a"]},
),
],
)
all_fw = {"csa_ccm_4.0": fw}
print_compliance_requirements(all_fw, ["csa_ccm_4.0"])
captured = capsys.readouterr().out
assert "Multi-provider" in captured
# ── Idempotency tests ────────────────────────────────────────────────
class TestIdempotency:
"""The function must be safe to invoke multiple times for the same
framework. Repeated calls must reuse writers tracked in
``generated_outputs["compliance"]`` instead of recreating them.
This guards against:
- duplicate writer entries in generated_outputs (regular pipeline
treats one writer per framework)
- the OCSF append-bug where a second writer would emit
``[...]<new>...]`` and break the JSON array.
"""
def test_second_call_does_not_duplicate_writers(self, tmp_path):
fw = _make_universal_framework()
generated = {"compliance": []}
kwargs = dict(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_output",
provider="aws",
generated_outputs=generated,
)
first = process_universal_compliance_frameworks(**kwargs)
first_count = len(generated["compliance"])
second = process_universal_compliance_frameworks(**kwargs)
second_count = len(generated["compliance"])
assert first == {"test_fw_1.0"}
assert second == {"test_fw_1.0"} # still reported as processed
assert first_count == 2 # CSV + OCSF
assert second_count == 2 # NO duplication
def test_second_call_keeps_ocsf_json_valid(self, tmp_path):
"""End-to-end: after two calls the OCSF JSON file must still be
a single, valid JSON array not the broken ``[...]...]`` form."""
fw = _make_universal_framework()
generated = {"compliance": []}
kwargs = dict(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_output",
provider="aws",
generated_outputs=generated,
)
process_universal_compliance_frameworks(**kwargs)
process_universal_compliance_frameworks(**kwargs)
ocsf_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.ocsf.json"
data = json.loads(ocsf_path.read_text()) # Will raise on invalid JSON
assert isinstance(data, list)
assert len(data) >= 1
def test_reuses_existing_writer_object(self, tmp_path):
"""The CSV/OCSF writer instances appended on first call must be
the SAME objects after a second call not fresh ones."""
fw = _make_universal_framework()
generated = {"compliance": []}
kwargs = dict(
input_compliance_frameworks={"test_fw_1.0"},
universal_frameworks={"test_fw_1.0": fw},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="prowler_output",
provider="aws",
generated_outputs=generated,
)
process_universal_compliance_frameworks(**kwargs)
first_writers = list(generated["compliance"])
process_universal_compliance_frameworks(**kwargs)
second_writers = list(generated["compliance"])
# Same identity, same length — reused, not recreated.
assert len(first_writers) == len(second_writers)
for a, b in zip(first_writers, second_writers):
assert a is b
def test_idempotency_across_mixed_frameworks(self, tmp_path):
"""When the second call adds a new framework, the new one is
created while existing ones are NOT recreated."""
fw1 = _make_universal_framework(name="FW1", version="1.0")
fw2 = _make_universal_framework(name="FW2", version="2.0")
generated = {"compliance": []}
# First call: only FW1
process_universal_compliance_frameworks(
input_compliance_frameworks={"fw1_1.0"},
universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
first_writers = list(generated["compliance"])
assert len(first_writers) == 2
# Second call: includes both. FW1 must be reused, FW2 created fresh.
process_universal_compliance_frameworks(
input_compliance_frameworks={"fw1_1.0", "fw2_2.0"},
universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
finding_outputs=[_make_finding("check_a")],
output_directory=str(tmp_path),
output_filename="out",
provider="aws",
generated_outputs=generated,
)
second_writers = list(generated["compliance"])
assert len(second_writers) == 4 # 2 (FW1 reused) + 2 new (FW2)
# FW1 writer instances unchanged
assert second_writers[0] is first_writers[0]
assert second_writers[1] is first_writers[1]
@@ -2,6 +2,7 @@ import json
from datetime import datetime, timezone
from types import SimpleNamespace
from py_ocsf_models.events.base_event import StatusID as EventStatusID
from py_ocsf_models.events.findings.compliance_finding import ComplianceFinding
from py_ocsf_models.events.findings.compliance_finding_type_id import (
ComplianceFindingTypeID,
@@ -18,6 +19,7 @@ from prowler.lib.check.compliance_models import (
)
from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
OCSFComplianceOutput,
_sanitize_resource_data,
)
@@ -473,3 +475,159 @@ class TestOCSFComplianceOutput:
cf = output.data[0]
assert cf.unmapped["requirement_attributes"]["section"] == "Logging"
assert "internal_note" not in cf.unmapped["requirement_attributes"]
class TestSanitizeResourceData:
"""Unit tests for the _sanitize_resource_data helper.
Service resources may carry non-JSON-serializable objects (e.g. raw
Pydantic models such as ``Trail`` or ``LifecyclePolicy``). The helper
must convert them so the resulting ComplianceFinding can be serialized.
"""
def test_dict_passthrough(self):
result = _sanitize_resource_data("details", {"a": 1, "b": "two"})
assert result == {"details": "details", "metadata": {"a": 1, "b": "two"}}
def test_none_metadata(self):
result = _sanitize_resource_data("details", None)
assert result == {"details": "details", "metadata": None}
def test_pydantic_v2_model_dump(self):
class FakeV2Model:
def model_dump(self):
return {"name": "trail-1", "region": "us-east-1"}
result = _sanitize_resource_data("d", {"trail": FakeV2Model()})
assert result["metadata"]["trail"] == {
"name": "trail-1",
"region": "us-east-1",
}
def test_pydantic_v1_dict(self):
class FakeV1Model:
def dict(self):
return {"name": "policy-1", "schedule": "daily"}
result = _sanitize_resource_data("d", {"policy": FakeV1Model()})
assert result["metadata"]["policy"] == {
"name": "policy-1",
"schedule": "daily",
}
def test_nested_pydantic_in_list(self):
class FakeModel:
def model_dump(self):
return {"id": "x"}
result = _sanitize_resource_data("d", {"items": [FakeModel(), FakeModel()]})
assert result["metadata"]["items"] == [{"id": "x"}, {"id": "x"}]
def test_nested_dict_recursion(self):
class FakeInner:
def model_dump(self):
return {"k": "v"}
result = _sanitize_resource_data(
"d", {"outer": {"inner": FakeInner(), "x": [1, 2]}}
)
assert result["metadata"]["outer"]["inner"] == {"k": "v"}
assert result["metadata"]["outer"]["x"] == [1, 2]
def test_tuple_to_list(self):
result = _sanitize_resource_data("d", {"t": (1, 2, "three")})
assert result["metadata"]["t"] == [1, 2, "three"]
def test_non_string_dict_keys_coerced(self):
result = _sanitize_resource_data("d", {1: "a", 2: "b"})
assert result["metadata"] == {"1": "a", "2": "b"}
def test_unknown_object_falls_back_to_str(self):
class Opaque:
def __str__(self):
return "opaque-repr"
result = _sanitize_resource_data("d", {"thing": Opaque()})
assert result["metadata"]["thing"] == "opaque-repr"
def test_circular_reference_falls_back_to_empty(self):
a = {}
a["self"] = a
# json.dumps raises ValueError on recursion → fallback to empty metadata
result = _sanitize_resource_data("d", a)
assert result == {"details": "d", "metadata": {}}
def test_serializes_via_full_finding_pipeline(self):
"""End-to-end: a finding with a non-serializable resource_metadata
produces a JSON-serializable ComplianceFinding."""
class TrailLike:
def __init__(self):
self.name = "trail-A"
self.kms_key_id = "arn:aws:kms:..."
def model_dump(self):
return {"name": self.name, "kms_key_id": self.kms_key_id}
finding = _make_finding("check_a")
finding.resource_metadata = {"trail": TrailLike()}
req = _simple_requirement()
fw = _make_framework([req])
output = OCSFComplianceOutput(findings=[finding], framework=fw, provider="aws")
# Serialize the resulting ComplianceFinding — must NOT raise
cf = output.data[0]
if hasattr(cf, "model_dump_json"):
json_output = cf.model_dump_json(exclude_none=True)
else:
json_output = cf.json(exclude_none=True)
payload = json.loads(json_output)
# Confirm the trail object made it through as a plain dict
assert payload["resources"][0]["data"]["metadata"]["trail"]["name"] == "trail-A"
class TestEventStatusInline:
"""Tests for the inlined event_status logic that replaced
OCSF.get_finding_status_id() to break the cyclic import."""
def test_unmuted_finding_status_new(self):
finding = _make_finding("check_a")
finding.muted = False
req = _simple_requirement()
fw = _make_framework([req])
output = OCSFComplianceOutput(findings=[finding], framework=fw, provider="aws")
cf = output.data[0]
assert cf.status_id == EventStatusID.New.value
assert cf.status == EventStatusID.New.name
def test_muted_finding_status_suppressed(self):
finding = _make_finding("check_a")
finding.muted = True
req = _simple_requirement()
fw = _make_framework([req])
output = OCSFComplianceOutput(findings=[finding], framework=fw, provider="aws")
cf = output.data[0]
assert cf.status_id == EventStatusID.Suppressed.value
assert cf.status == EventStatusID.Suppressed.name
class TestNoTopLevelOCSFImport:
"""Regression test: the top-level OCSF/Finding imports were removed
to break the CodeQL cyclic-import warnings. Ensure they stay out of
the runtime namespace of the module (TYPE_CHECKING block only)."""
def test_finding_not_in_runtime_namespace(self):
import prowler.lib.outputs.compliance.universal.ocsf_compliance as mod
assert "Finding" not in dir(mod)
def test_ocsf_class_not_imported(self):
import prowler.lib.outputs.compliance.universal.ocsf_compliance as mod
assert "OCSF" not in dir(mod)
@@ -0,0 +1,568 @@
from types import SimpleNamespace
from prowler.lib.check.compliance_models import (
AttributeMetadata,
ComplianceFramework,
OutputFormats,
OutputsConfig,
TableConfig,
UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.universal.universal_output import (
UniversalComplianceOutput,
)
def _make_finding(check_id, status="PASS", compliance_map=None):
"""Create a mock Finding for output tests."""
finding = SimpleNamespace()
finding.provider = "aws"
finding.account_uid = "123456789012"
finding.account_name = "test-account"
finding.region = "us-east-1"
finding.status = status
finding.status_extended = f"{check_id} is {status}"
finding.resource_uid = f"arn:aws:iam::123456789012:{check_id}"
finding.resource_name = check_id
finding.muted = False
finding.check_id = check_id
finding.metadata = SimpleNamespace(
Provider="aws",
CheckID=check_id,
Severity="medium",
)
finding.compliance = compliance_map or {}
return finding
def _make_framework(requirements, attrs_metadata=None, table_config=None):
return ComplianceFramework(
framework="TestFW",
name="Test Framework",
provider="AWS",
version="1.0",
description="Test framework",
requirements=requirements,
attributes_metadata=attrs_metadata,
outputs=OutputsConfig(table_config=table_config) if table_config else None,
)
class TestDynamicCSVColumns:
def test_columns_match_metadata(self, tmp_path):
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM", "SubSection": "Auth"},
checks={"aws": ["check_a"]},
),
]
metadata = [
AttributeMetadata(key="Section", type="str"),
AttributeMetadata(key="SubSection", type="str"),
]
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
findings = [
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
)
assert len(output.data) == 1
row_dict = output.data[0].dict()
assert "Requirements_Attributes_Section" in row_dict
assert "Requirements_Attributes_SubSection" in row_dict
assert row_dict["Requirements_Attributes_Section"] == "IAM"
assert row_dict["Requirements_Attributes_SubSection"] == "Auth"
class TestManualRequirements:
def test_manual_status(self, tmp_path):
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM"},
checks={"aws": ["check_a"]},
),
UniversalComplianceRequirement(
id="manual-1",
description="manual check",
attributes={"Section": "Governance"},
checks={},
),
]
metadata = [
AttributeMetadata(key="Section", type="str"),
]
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
findings = [
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
)
# Should have 1 real finding + 1 manual
assert len(output.data) == 2
manual_rows = [r for r in output.data if r.dict()["Status"] == "MANUAL"]
assert len(manual_rows) == 1
assert manual_rows[0].dict()["Requirements_Id"] == "manual-1"
assert manual_rows[0].dict()["ResourceId"] == "manual_check"
class TestMITREExtraColumns:
def test_mitre_columns_present(self, tmp_path):
reqs = [
UniversalComplianceRequirement(
id="T1190",
description="Exploit",
attributes={},
checks={"aws": ["check_a"]},
tactics=["Initial Access"],
sub_techniques=[],
platforms=["IaaS"],
technique_url="https://attack.mitre.org/techniques/T1190/",
),
]
fw = _make_framework(reqs, None, TableConfig(group_by="_Tactics"))
findings = [
_make_finding("check_a", "PASS", {"TestFW-1.0": ["T1190"]}),
]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
)
assert len(output.data) == 1
row_dict = output.data[0].dict()
assert "Requirements_Tactics" in row_dict
assert row_dict["Requirements_Tactics"] == "Initial Access"
assert "Requirements_TechniqueURL" in row_dict
class TestCSVFileWrite:
def test_batch_write(self, tmp_path):
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM"},
checks={"aws": ["check_a"]},
),
]
metadata = [
AttributeMetadata(key="Section", type="str"),
]
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
findings = [
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
)
output.batch_write_data_to_file()
# Verify file was created and has content
with open(filepath, "r") as f:
content = f.read()
assert "PROVIDER" in content # Headers are uppercase
assert "REQUIREMENTS_ATTRIBUTES_SECTION" in content
assert "IAM" in content
class TestNoFindings:
def test_empty_findings_no_data(self, tmp_path):
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM"},
checks={"aws": ["check_a"]},
),
]
fw = _make_framework(reqs, None, TableConfig(group_by="Section"))
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=[],
framework=fw,
file_path=filepath,
)
assert len(output.data) == 0
class TestMultiProviderOutput:
def test_dict_checks_filtered_by_provider(self, tmp_path):
"""Only checks for the given provider appear in CSV output."""
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM"},
checks={"aws": ["check_a"], "azure": ["check_b"]},
),
]
metadata = [
AttributeMetadata(key="Section", type="str"),
]
fw = ComplianceFramework(
framework="MultiCloud",
name="Multi",
version="1.0",
description="Test multi-provider",
requirements=reqs,
attributes_metadata=metadata,
outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
)
findings = [
_make_finding("check_a", "PASS", {"MultiCloud-1.0": ["1.1"]}),
_make_finding("check_b", "FAIL", {"MultiCloud-1.0": ["1.1"]}),
]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
provider="aws",
)
# Only check_a should match (it's the AWS check)
assert len(output.data) == 1
row_dict = output.data[0].dict()
assert row_dict["Requirements_Attributes_Section"] == "IAM"
def test_no_provider_includes_all(self, tmp_path):
"""Without provider filter, all checks from all providers are included."""
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM"},
checks={"aws": ["check_a"], "azure": ["check_b"]},
),
]
metadata = [
AttributeMetadata(key="Section", type="str"),
]
fw = ComplianceFramework(
framework="MultiCloud",
name="Multi",
version="1.0",
description="Test multi-provider",
requirements=reqs,
attributes_metadata=metadata,
outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
)
findings = [
_make_finding("check_a", "PASS", {"MultiCloud-1.0": ["1.1"]}),
_make_finding("check_b", "FAIL", {"MultiCloud-1.0": ["1.1"]}),
]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
)
# Both checks should be included without provider filter
assert len(output.data) == 2
def test_empty_dict_checks_is_manual(self, tmp_path):
"""Requirement with empty dict checks is treated as manual."""
reqs = [
UniversalComplianceRequirement(
id="manual-1",
description="manual check",
attributes={"Section": "Governance"},
checks={},
),
]
metadata = [
AttributeMetadata(key="Section", type="str"),
]
fw = ComplianceFramework(
framework="MultiCloud",
name="Multi",
version="1.0",
description="Test",
requirements=reqs,
attributes_metadata=metadata,
outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
)
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=[_make_finding("other_check", "PASS", {})],
framework=fw,
file_path=filepath,
provider="aws",
)
manual_rows = [r for r in output.data if r.dict()["Status"] == "MANUAL"]
assert len(manual_rows) == 1
assert manual_rows[0].dict()["Requirements_Id"] == "manual-1"
class TestCSVExclude:
def test_csv_false_excludes_column(self, tmp_path):
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM", "Internal": "hidden"},
checks={"aws": ["check_a"]},
),
]
metadata = [
AttributeMetadata(
key="Section", type="str", output_formats=OutputFormats(csv=True)
),
AttributeMetadata(
key="Internal", type="str", output_formats=OutputFormats(csv=False)
),
]
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
findings = [
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
)
row_dict = output.data[0].dict()
assert "Requirements_Attributes_Section" in row_dict
assert "Requirements_Attributes_Internal" not in row_dict
def _make_provider_finding(provider, check_id="check_a", status="PASS"):
"""Create a mock Finding with a specific provider."""
finding = _make_finding(check_id, status, {"TestFW-1.0": ["1.1"]})
finding.provider = provider
return finding
def _simple_framework():
all_providers = [
"aws",
"azure",
"gcp",
"kubernetes",
"m365",
"github",
"oraclecloud",
"alibabacloud",
"nhn",
"unknown",
]
reqs = [
UniversalComplianceRequirement(
id="1.1",
description="test",
attributes={"Section": "IAM"},
checks={p: ["check_a"] for p in all_providers},
),
]
metadata = [
AttributeMetadata(key="Section", type="str"),
]
return _make_framework(reqs, metadata, TableConfig(group_by="Section"))
class TestProviderHeaders:
def test_aws_headers(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("aws")]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / "test.csv"),
provider="aws",
)
row_dict = output.data[0].dict()
assert "AccountId" in row_dict
assert "Region" in row_dict
assert row_dict["AccountId"] == "123456789012"
assert row_dict["Region"] == "us-east-1"
def test_azure_headers(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("azure")]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / "test.csv"),
provider="azure",
)
row_dict = output.data[0].dict()
assert "SubscriptionId" in row_dict
assert "Location" in row_dict
assert row_dict["SubscriptionId"] == "123456789012"
assert row_dict["Location"] == "us-east-1"
def test_gcp_headers(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("gcp")]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / "test.csv"),
provider="gcp",
)
row_dict = output.data[0].dict()
assert "ProjectId" in row_dict
assert "Location" in row_dict
assert row_dict["ProjectId"] == "123456789012"
def test_kubernetes_headers(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("kubernetes")]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / "test.csv"),
provider="kubernetes",
)
row_dict = output.data[0].dict()
assert "Context" in row_dict
assert "Namespace" in row_dict
# Kubernetes Context maps to account_name
assert row_dict["Context"] == "test-account"
assert row_dict["Namespace"] == "us-east-1"
def test_github_headers(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("github")]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / "test.csv"),
provider="github",
)
row_dict = output.data[0].dict()
assert "Account_Name" in row_dict
assert "Account_Id" in row_dict
# GitHub: Account_Name (pos 3) from account_name, Account_Id (pos 4) from account_uid
assert row_dict["Account_Name"] == "test-account"
assert row_dict["Account_Id"] == "123456789012"
# Verify column order matches legacy (Account_Name before Account_Id)
keys = list(row_dict.keys())
assert keys.index("Account_Name") < keys.index("Account_Id")
def test_unknown_provider_defaults(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("unknown")]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / "test.csv"),
provider="unknown",
)
row_dict = output.data[0].dict()
assert "AccountId" in row_dict
assert "Region" in row_dict
def test_none_provider_defaults(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("aws")]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / "test.csv"),
)
row_dict = output.data[0].dict()
assert "AccountId" in row_dict
assert "Region" in row_dict
def test_csv_write_azure_headers(self, tmp_path):
fw = _simple_framework()
findings = [_make_provider_finding("azure")]
filepath = str(tmp_path / "test.csv")
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=filepath,
provider="azure",
)
output.batch_write_data_to_file()
with open(filepath, "r") as f:
content = f.read()
assert "SUBSCRIPTIONID" in content
assert "LOCATION" in content
# Should NOT have the default AccountId/Region headers
assert "ACCOUNTID" not in content
def test_column_order_matches_legacy(self, tmp_path):
"""Verify that the base column order matches the legacy per-provider models.
Legacy models all define: Provider, Description, <col3>, <col4>, AssessmentDate, ...
The universal output must preserve this exact order for backward compatibility.
"""
# Expected column order per provider (positions 3 and 4 after Provider, Description)
legacy_order = {
"aws": ("AccountId", "Region"),
"azure": ("SubscriptionId", "Location"),
"gcp": ("ProjectId", "Location"),
"kubernetes": ("Context", "Namespace"),
"m365": ("TenantId", "Location"),
"github": ("Account_Name", "Account_Id"),
"oraclecloud": ("TenancyId", "Region"),
"alibabacloud": ("AccountId", "Region"),
"nhn": ("AccountId", "Region"),
}
for provider_name, (expected_col3, expected_col4) in legacy_order.items():
fw = _simple_framework()
findings = [_make_provider_finding(provider_name)]
output = UniversalComplianceOutput(
findings=findings,
framework=fw,
file_path=str(tmp_path / f"test_{provider_name}.csv"),
provider=provider_name,
)
keys = list(output.data[0].dict().keys())
assert keys[0] == "Provider", f"{provider_name}: col 1 should be Provider"
assert (
keys[1] == "Description"
), f"{provider_name}: col 2 should be Description"
assert (
keys[2] == expected_col3
), f"{provider_name}: col 3 should be {expected_col3}, got {keys[2]}"
assert (
keys[3] == expected_col4
), f"{provider_name}: col 4 should be {expected_col4}, got {keys[3]}"
assert (
keys[4] == "AssessmentDate"
), f"{provider_name}: col 5 should be AssessmentDate"
+7
View File
@@ -21,6 +21,7 @@ from prowler.providers.aws.config import (
AWS_STS_GLOBAL_ENDPOINT_REGION,
BOTO3_USER_AGENT_EXTRA,
ROLE_SESSION_NAME,
get_default_session_config,
)
from prowler.providers.aws.exceptions.exceptions import (
AWSArgumentTypeValidationError,
@@ -2242,6 +2243,12 @@ aws:
assert session_config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
assert session_config.retries == {"max_attempts": 10, "mode": "standard"}
def test_get_default_session_config(self):
config = get_default_session_config()
assert config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
assert config.retries == {"max_attempts": 3, "mode": "standard"}
@mock_aws
@patch(
"prowler.lib.check.utils.recover_checks_from_provider",
@@ -4,6 +4,8 @@ import boto3
from botocore.exceptions import ClientError
from moto import mock_aws
from prowler.providers.aws.aws_provider import AwsProvider
from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
from prowler.providers.aws.lib.organizations.organizations import (
_get_ou_metadata,
get_organizations_metadata,
@@ -222,6 +224,20 @@ class Test_AWS_Organizations:
assert tags == {}
assert ou_metadata == {}
def test_get_organizations_metadata_uses_user_agent_extra(self):
real_session = boto3.Session()
real_session._session.set_default_client_config(
AwsProvider.set_session_config(None)
)
wrapper = MagicMock(wraps=real_session)
get_organizations_metadata("123456789012", wrapper)
wrapper.client.assert_called_once()
default_config = real_session._session.get_default_client_config()
assert default_config is not None
assert BOTO3_USER_AGENT_EXTRA in default_config.user_agent_extra
def test_parse_organizations_metadata_with_empty_ou_metadata(self):
tags = {"Tags": []}
metadata = {
@@ -1,5 +1,6 @@
from mock import patch
from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
from prowler.providers.aws.lib.service.service import AWSService
from tests.providers.aws.utils import (
AWS_ACCOUNT_ARN,
@@ -189,6 +190,15 @@ class TestAWSService:
== f"arn:{service.audited_partition}:{service_name}::{AWS_ACCOUNT_NUMBER}:bucket/unknown"
)
def test_AWSService_clients_carry_user_agent_extra(self):
provider = set_mocked_aws_provider()
service = AWSService("s3", provider)
ad_hoc_client = service.session.client("ec2", AWS_REGION_US_EAST_1)
assert BOTO3_USER_AGENT_EXTRA in service.client._client_config.user_agent_extra
assert BOTO3_USER_AGENT_EXTRA in ad_hoc_client._client_config.user_agent_extra
def test_AWSService_get_unknown_arn_resource_type_set_region(self):
service_name = "s3"
provider = set_mocked_aws_provider()
@@ -45,11 +45,12 @@ def mock_make_api_call(self, operation_name, kwarg):
elif operation_name == "ListBuildsForProject":
return {"ids": [build_id]}
elif operation_name == "BatchGetBuilds":
return {"builds": [{"endTime": last_invoked_time}]}
return {"builds": [{"id": build_id, "endTime": last_invoked_time}]}
elif operation_name == "BatchGetProjects":
return {
"projects": [
{
"arn": project_arn,
"source": {
"type": source_type,
"location": bitbucket_url,
@@ -230,3 +231,97 @@ class Test_Codebuild_Service:
assert (
codebuild.report_groups[report_group_arn].tags[0]["value"] == project_name
)
# Module-level state and helpers used by the chunking/out-of-order test below.
# Kept at module level so the API-call mock is a plain function rather than a
# closure defined inside the test method.
TOTAL_PROJECTS = 150
many_project_names = [f"project-{i}" for i in range(TOTAL_PROJECTS)]
many_project_arns = [
f"arn:{AWS_COMMERCIAL_PARTITION}:codebuild:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:project/{name}"
for name in many_project_names
]
many_build_ids_for = {name: f"{name}:build-id" for name in many_project_names}
many_end_times_for = {
name: datetime.now() - timedelta(days=i)
for i, name in enumerate(many_project_names)
}
many_name_by_build_id = {v: k for k, v in many_build_ids_for.items()}
many_batch_call_sizes = {"BatchGetProjects": [], "BatchGetBuilds": []}
def mock_make_api_call_many_projects(self, operation_name, kwarg):
if operation_name == "ListProjects":
return {"projects": many_project_names}
if operation_name == "ListBuildsForProject":
return {"ids": [many_build_ids_for[kwarg["projectName"]]]}
if operation_name == "BatchGetBuilds":
ids = kwarg["ids"]
many_batch_call_sizes["BatchGetBuilds"].append(len(ids))
# Reverse the response order to verify id->project mapping does not
# depend on response ordering.
builds = [
{"id": bid, "endTime": many_end_times_for[many_name_by_build_id[bid]]}
for bid in reversed(ids)
]
return {"builds": builds}
if operation_name == "BatchGetProjects":
names = kwarg["names"]
many_batch_call_sizes["BatchGetProjects"].append(len(names))
# Reverse the response order to verify arn->project mapping does not
# depend on response ordering.
projects = [
{
"arn": f"arn:{AWS_COMMERCIAL_PARTITION}:codebuild:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:project/{name}",
"source": {"type": "NO_SOURCE"},
"logsConfig": {},
"tags": [],
"projectVisibility": "PRIVATE",
}
for name in reversed(names)
]
return {"projects": projects}
if operation_name == "ListReportGroups":
return {"reportGroups": []}
return make_api_call(self, operation_name, kwarg)
class Test_Codebuild_Service_Batching:
@patch(
"botocore.client.BaseClient._make_api_call",
new=mock_make_api_call_many_projects,
)
@patch(
"prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients",
new=mock_generate_regional_clients,
)
@mock_aws
def test_codebuild_batches_chunks_over_100_projects_and_maps_out_of_order_responses(
self,
):
"""Verify _batch_get_projects/_batch_get_builds chunk in groups of 100
and correctly map out-of-order batch responses back to the right
project using `arn`/`id`.
"""
# Reset the per-test recorder (module-level state survives across runs).
many_batch_call_sizes["BatchGetProjects"].clear()
many_batch_call_sizes["BatchGetBuilds"].clear()
codebuild = Codebuild(set_mocked_aws_provider([AWS_REGION_EU_WEST_1]))
# Verify chunking: 150 items -> two batches of 100 and 50.
assert sorted(many_batch_call_sizes["BatchGetProjects"]) == [50, 100]
assert sorted(many_batch_call_sizes["BatchGetBuilds"]) == [50, 100]
# Verify all projects were tracked.
assert len(codebuild.projects) == TOTAL_PROJECTS
# Verify out-of-order responses were correctly mapped back to the
# right project by `arn` (projects) and `id` (builds).
for name, arn in zip(many_project_names, many_project_arns):
project = codebuild.projects[arn]
assert project.name == name
assert project.project_visibility == "PRIVATE"
assert project.last_build == Build(id=many_build_ids_for[name])
assert project.last_invoked_time == many_end_times_for[name]
+3 -3
View File
@@ -2,7 +2,6 @@ from argparse import Namespace
from json import dumps
from boto3 import client, session
from botocore.config import Config
from moto import mock_aws
from prowler.config.config import (
@@ -133,10 +132,11 @@ def set_mocked_aws_provider(
provider = AwsProvider()
# Mock Session
provider._session.session_config = None
session_config = AwsProvider.set_session_config(None)
provider._session.session_config = session_config
provider._session.original_session = original_session
provider._session.current_session = audit_session
provider._session.session_config = Config()
audit_session._session.set_default_client_config(session_config)
# Mock Identity
provider._identity.account = audited_account
provider._identity.account_arn = audited_account_arn