Mirror of https://github.com/prowler-cloud/prowler.git (synced 2026-05-06 08:47:18 +00:00)
Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 5e488e2ee6 | |
| | 8db3a89669 | |
| | c802dc8a36 | |
| | 3ab9a4efa5 | |
| | 36b8aa1b79 | |
| | e821e07d7d | |
| | 228fe6d579 | |
| | 578186aa40 | |
```diff
@@ -62,7 +62,7 @@ jobs:
             "Alan-TheGentleman"
             "alejandrobailo"
             "amitsharm"
-            "andoniaf"
+            # "andoniaf"
             "cesararroba"
             "danibarranqueroo"
             "HugoPBrito"
```
```diff
@@ -98,7 +98,7 @@ repos:
   ## PYTHON — API + MCP Server (ruff)
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.15.11
+    rev: v0.15.12
    hooks:
      - id: ruff
        name: "API + MCP - ruff check"
```
```diff
@@ -227,6 +227,7 @@ Assign administrative permissions by selecting from the following options:
 | Manage Integrations | All | Add or modify the Prowler Integrations. |
+| Manage Ingestions | Prowler Cloud | Allow or deny the ability to submit findings ingestion batches via the API. |
 | Manage Billing | Prowler Cloud | Access and manage billing settings and subscription information. |
 | Manage Alerts | Prowler Cloud | Create, edit, and delete alert rules and recipients. |
 
 <Note>
 The **Scope** column indicates where each permission applies. **All** means the permission is available in both Prowler Cloud and Self-Managed deployments. **Prowler Cloud** indicates permissions that are specific to [Prowler Cloud](https://cloud.prowler.com/sign-in).
@@ -241,3 +242,5 @@ The following permissions are available exclusively in **Prowler Cloud**:
+**Manage Ingestions:** Submit and manage findings ingestion jobs via the API. Required to upload OCSF scan results using the `--push-to-cloud` CLI flag or the ingestion endpoints. See [Import Findings](/user-guide/tutorials/prowler-app-import-findings) for details.
+
 **Manage Billing:** Access and manage billing settings, subscription plans, and payment methods.
 
 **Manage Alerts:** Create, edit, and delete alert rules and recipients used to deliver scan-result digests via email.
```
@@ -7,16 +7,21 @@ All notable changes to the **Prowler SDK** are documented in this file.

### 🚀 Added

- `bedrock_guardrails_configured` check for AWS provider [(#10844)](https://github.com/prowler-cloud/prowler/pull/10844)
- Universal compliance pipeline integrated into the CLI: `--list-compliance` and `--list-compliance-requirements` show universal frameworks, and CSV plus OCSF outputs are generated for any framework declaring a `TableConfig` [(#10301)](https://github.com/prowler-cloud/prowler/pull/10301)
- ASD Essential Eight Maturity Model compliance framework for AWS (Maturity Level One, Nov 2023) [(#10808)](https://github.com/prowler-cloud/prowler/pull/10808)

### 🔄 Changed

- `route53_dangling_ip_subdomain_takeover` now also flags `CNAME` records pointing to S3 website endpoints whose buckets are missing from the account [(#10920)](https://github.com/prowler-cloud/prowler/pull/10920)
- Azure Network Watcher flow log checks now require workspace-backed Traffic Analytics for `network_flow_log_captured_sent` and align metadata with VNet-compatible flow log guidance [(#10645)](https://github.com/prowler-cloud/prowler/pull/10645)
- Azure compliance entries for legacy Network Watcher flow log controls now use retirement-aware guidance and point new deployments to VNet flow logs
- AWS CodeBuild service now batches `BatchGetProjects` and `BatchGetBuilds` calls per region (up to 100 items per call) to reduce API call volume and prevent throttling-induced false positives in `codebuild_project_not_publicly_accessible` [(#10639)](https://github.com/prowler-cloud/prowler/pull/10639)
- `display_compliance_table` dispatch switched from substring `in` checks to `startswith` to prevent false matches between similarly named frameworks (e.g. `cisa` vs `cis`) [(#10301)](https://github.com/prowler-cloud/prowler/pull/10301)

### 🐞 Fixed

- AWS SDK test isolation: autouse `mock_aws` fixture and leak detector in `conftest.py` to prevent tests from hitting real AWS endpoints, with idempotent organization setup for tests calling `set_mocked_aws_provider` multiple times [(#10605)](https://github.com/prowler-cloud/prowler/pull/10605)
- AWS `boto` user agent extra is now applied to every client [(#10944)](https://github.com/prowler-cloud/prowler/pull/10944)

### 🔐 Security
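The CodeBuild entry in the changelog describes a standard chunking pattern: split a long id list into batches no larger than the API limit so one call covers many items. A minimal sketch, where `client` stands in for a boto3 CodeBuild client (not created here) and is the only assumption beyond the 100-item limit stated above:

```python
# Chunking pattern from the changelog entry: CodeBuild's BatchGetProjects /
# BatchGetBuilds accept at most 100 identifiers per call, so long lists are
# split into <=100-item chunks first. `client` is an illustrative stand-in
# for a boto3 CodeBuild client.

BATCH_SIZE = 100


def chunked(items, size=BATCH_SIZE):
    """Yield successive chunks of at most `size` items."""
    for i in range(0, len(items), size):
        yield items[i : i + size]


def batch_get_projects(client, project_names):
    """Fetch project details with one API call per <=100-name chunk."""
    projects = []
    for chunk in chunked(project_names):
        response = client.batch_get_projects(names=list(chunk))
        projects.extend(response.get("projects", []))
    return projects
```

With 250 project names this issues three API calls instead of 250, which is what keeps throttling (and the resulting false positives) out of the check.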
+56 -13
```diff
@@ -45,7 +45,10 @@ from prowler.lib.check.check import (
 )
 from prowler.lib.check.checks_loader import load_checks_to_execute
 from prowler.lib.check.compliance import update_checks_metadata_with_compliance
-from prowler.lib.check.compliance_models import Compliance
+from prowler.lib.check.compliance_models import (
+    Compliance,
+    get_bulk_compliance_frameworks_universal,
+)
 from prowler.lib.check.custom_checks_metadata import (
     parse_custom_checks_metadata_file,
     update_checks_metadata,
```
```diff
@@ -75,7 +78,10 @@ from prowler.lib.outputs.compliance.cis.cis_oraclecloud import OracleCloudCIS
 from prowler.lib.outputs.compliance.cisa_scuba.cisa_scuba_googleworkspace import (
     GoogleWorkspaceCISASCuBA,
 )
-from prowler.lib.outputs.compliance.compliance import display_compliance_table
+from prowler.lib.outputs.compliance.compliance import (
+    display_compliance_table,
+    process_universal_compliance_frameworks,
+)
 from prowler.lib.outputs.compliance.csa.csa_alibabacloud import AlibabaCloudCSA
 from prowler.lib.outputs.compliance.csa.csa_aws import AWSCSA
 from prowler.lib.outputs.compliance.csa.csa_azure import AzureCSA
```
```diff
@@ -84,6 +90,9 @@ from prowler.lib.outputs.compliance.csa.csa_oraclecloud import OracleCloudCSA
 from prowler.lib.outputs.compliance.ens.ens_aws import AWSENS
 from prowler.lib.outputs.compliance.ens.ens_azure import AzureENS
 from prowler.lib.outputs.compliance.ens.ens_gcp import GCPENS
+from prowler.lib.outputs.compliance.essential_eight.essential_eight_aws import (
+    EssentialEightAWS,
+)
 from prowler.lib.outputs.compliance.generic.generic import GenericCompliance
 from prowler.lib.outputs.compliance.iso27001.iso27001_aws import AWSISO27001
 from prowler.lib.outputs.compliance.iso27001.iso27001_azure import AzureISO27001
```
```diff
@@ -235,6 +244,8 @@ def prowler():
     # Load compliance frameworks
     logger.debug("Loading compliance frameworks from .json files")
 
+    universal_frameworks = {}
+
     # Skip compliance frameworks for external-tool providers
     if provider not in EXTERNAL_TOOL_PROVIDERS:
         bulk_compliance_frameworks = Compliance.get_bulk(provider)
```
```diff
@@ -242,6 +253,8 @@ def prowler():
         bulk_checks_metadata = update_checks_metadata_with_compliance(
             bulk_compliance_frameworks, bulk_checks_metadata
         )
+        # Load universal compliance frameworks for new rendering pipeline
+        universal_frameworks = get_bulk_compliance_frameworks_universal(provider)
 
     # Update checks metadata if the --custom-checks-metadata-file is present
     custom_checks_metadata = None
```
```diff
@@ -254,12 +267,12 @@ def prowler():
     )
 
     if args.list_compliance:
-        print_compliance_frameworks(bulk_compliance_frameworks)
+        all_frameworks = {**bulk_compliance_frameworks, **universal_frameworks}
+        print_compliance_frameworks(all_frameworks)
         sys.exit()
     if args.list_compliance_requirements:
-        print_compliance_requirements(
-            bulk_compliance_frameworks, args.list_compliance_requirements
-        )
+        all_frameworks = {**bulk_compliance_frameworks, **universal_frameworks}
+        print_compliance_requirements(all_frameworks, args.list_compliance_requirements)
         sys.exit()
 
     # Load checks to execute
```
```diff
@@ -276,6 +289,7 @@ def prowler():
         provider=provider,
         list_checks=getattr(args, "list_checks", False)
         or getattr(args, "list_checks_json", False),
+        universal_frameworks=universal_frameworks,
     )
 
     # if --list-checks-json, dump a json file and exit
```
```diff
@@ -624,15 +638,29 @@ def prowler():
         )
 
         # Compliance Frameworks
-        # Source the framework listing from `bulk_compliance_frameworks.keys()`
-        # so it is by construction a subset of what the bulk loader can resolve.
-        # `get_available_compliance_frameworks(provider)` also discovers top-level
-        # multi-provider universal JSONs (e.g. `prowler/compliance/csa_ccm_4.0.json`)
-        # which `Compliance.get_bulk(provider)` does not load, and which the legacy
-        # output handlers below cannot consume — using it as the source produced
+        # Source the framework listing from the union of `bulk_compliance_frameworks`
+        # and `universal_frameworks` so universal-only frameworks (e.g.
+        # `prowler/compliance/csa_ccm_4.0.json`) — which `Compliance.get_bulk(provider)`
+        # does not load — still reach `process_universal_compliance_frameworks` below.
+        # The provider-specific block subtracts the names handled by the universal
+        # processor so the legacy per-provider handlers only see frameworks that the
+        # bulk loader actually resolved.
         input_compliance_frameworks = set(output_options.output_modes).intersection(
-            bulk_compliance_frameworks.keys()
+            set(bulk_compliance_frameworks.keys()) | set(universal_frameworks.keys())
         )
 
+        # ── Universal compliance frameworks (provider-agnostic) ──
+        universal_processed = process_universal_compliance_frameworks(
+            input_compliance_frameworks=input_compliance_frameworks,
+            universal_frameworks=universal_frameworks,
+            finding_outputs=finding_outputs,
+            output_directory=output_options.output_directory,
+            output_filename=output_options.output_filename,
+            provider=provider,
+            generated_outputs=generated_outputs,
+        )
+        input_compliance_frameworks -= universal_processed
+
         if provider == "aws":
             for compliance_name in input_compliance_frameworks:
                 if compliance_name.startswith("cis_"):
```
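The routing logic in this hunk reduces to three set operations: intersect the requested output modes with everything either loader knows, process the universal frameworks, then subtract them so the legacy handlers never see them. A standalone sketch with illustrative framework names:

```python
# Illustrative sketch of the framework-routing set logic above. All
# framework names and dict values are made up for the example.
requested = {"cis_4.0_aws", "csa_ccm_4.0", "json-ocsf"}
bulk_frameworks = {"cis_4.0_aws": "...", "mitre_attack_aws": "..."}
universal_frameworks = {"csa_ccm_4.0": "..."}

# 1. Keep only requested modes that either loader can resolve
#    (this drops plain output modes such as "json-ocsf").
input_frameworks = requested & (bulk_frameworks.keys() | universal_frameworks.keys())

# 2. "Process" the universal frameworks (here reduced to name matching).
universal_processed = input_frameworks & universal_frameworks.keys()

# 3. Legacy per-provider handlers only see what the bulk loader resolved.
input_frameworks -= universal_processed
```

After step 3 the legacy loop sees only `cis_4.0_aws`, while `csa_ccm_4.0` was handled by the universal processor.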
```diff
@@ -648,6 +676,18 @@ def prowler():
                     )
                     generated_outputs["compliance"].append(cis)
                     cis.batch_write_data_to_file()
+                elif compliance_name.startswith("essential_eight"):
+                    filename = (
+                        f"{output_options.output_directory}/compliance/"
+                        f"{output_options.output_filename}_{compliance_name}.csv"
+                    )
+                    essential_eight = EssentialEightAWS(
+                        findings=finding_outputs,
+                        compliance=bulk_compliance_frameworks[compliance_name],
+                        file_path=filename,
+                    )
+                    generated_outputs["compliance"].append(essential_eight)
+                    essential_eight.batch_write_data_to_file()
                 elif compliance_name == "mitre_attack_aws":
                     # Generate MITRE ATT&CK Finding Object
                     filename = (
```
```diff
@@ -1402,6 +1442,9 @@ def prowler():
                 output_options.output_filename,
                 output_options.output_directory,
                 compliance_overview,
+                universal_frameworks=universal_frameworks,
+                provider=provider,
+                output_formats=args.output_formats,
             )
             if compliance_overview:
                 print(
```
(File diff suppressed because it is too large.)
```diff
@@ -87,8 +87,8 @@ def get_available_compliance_frameworks(provider=None):
     providers = [p.value for p in Provider]
     if provider:
         providers = [provider]
-    for provider in providers:
-        compliance_dir = f"{actual_directory}/../compliance/{provider}"
+    for current_provider in providers:
+        compliance_dir = f"{actual_directory}/../compliance/{current_provider}"
         if not os.path.isdir(compliance_dir):
             continue
         with os.scandir(compliance_dir) as files:
```
```diff
@@ -97,7 +97,9 @@ def get_available_compliance_frameworks(provider=None):
                     available_compliance_frameworks.append(
                         file.name.removesuffix(".json")
                     )
-    # Also scan top-level compliance/ for multi-provider JSONs
+    # Also scan top-level compliance/ for multi-provider (universal) JSONs.
+    # When a specific provider was requested, only include the framework if it
+    # declares support for that provider; otherwise include all universal frameworks.
     compliance_root = f"{actual_directory}/../compliance"
     if os.path.isdir(compliance_root):
         with os.scandir(compliance_root) as files:
```
```diff
@@ -299,12 +299,22 @@ def print_compliance_frameworks(
 def print_compliance_requirements(
     bulk_compliance_frameworks: dict, compliance_frameworks: list
 ):
+    from prowler.lib.check.compliance_models import ComplianceFramework
+
     for compliance_framework in compliance_frameworks:
         for key in bulk_compliance_frameworks.keys():
-            framework = bulk_compliance_frameworks[key].Framework
-            provider = bulk_compliance_frameworks[key].Provider
-            version = bulk_compliance_frameworks[key].Version
-            requirements = bulk_compliance_frameworks[key].Requirements
+            entry = bulk_compliance_frameworks[key]
+            is_universal = isinstance(entry, ComplianceFramework)
+            if is_universal:
+                framework = entry.framework
+                provider = entry.provider or "Multi-provider"
+                version = entry.version
+                requirements = entry.requirements
+            else:
+                framework = entry.Framework
+                provider = entry.Provider or "Multi-provider"
+                version = entry.Version
+                requirements = entry.Requirements
             # We can list the compliance requirements for a given framework using the
             # bulk_compliance_frameworks keys since they are the compliance specification file name
             if compliance_framework == key:
@@ -313,10 +323,23 @@ def print_compliance_requirements(
                 )
                 for requirement in requirements:
                     checks = ""
-                    for check in requirement.Checks:
-                        checks += f"  {Fore.YELLOW}\t\t{check}\n{Style.RESET_ALL}"
+                    if is_universal:
+                        req_checks = requirement.checks
+                        req_id = requirement.id
+                        req_description = requirement.description
+                    else:
+                        req_checks = requirement.Checks
+                        req_id = requirement.Id
+                        req_description = requirement.Description
+                    if isinstance(req_checks, dict):
+                        for prov, check_list in req_checks.items():
+                            for check in check_list:
+                                checks += f"  {Fore.YELLOW}\t\t[{prov}] {check}\n{Style.RESET_ALL}"
+                    else:
+                        for check in req_checks:
+                            checks += f"  {Fore.YELLOW}\t\t{check}\n{Style.RESET_ALL}"
                     print(
-                        f"Requirement Id: {Fore.MAGENTA}{requirement.Id}{Style.RESET_ALL}\n\t- Description: {requirement.Description}\n\t- Checks:\n{checks}"
+                        f"Requirement Id: {Fore.MAGENTA}{req_id}{Style.RESET_ALL}\n\t- Description: {req_description}\n\t- Checks:\n{checks}"
                     )
```
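The pattern in this hunk, normalizing two requirement shapes (legacy CapitalCase fields with a flat check list, universal snake_case fields with a per-provider check dict) behind one `isinstance` test, can be sketched on its own. The dataclasses below are illustrative stand-ins, not Prowler's real models:

```python
# Standalone sketch of the dual-shape handling above. Both classes are
# hypothetical stand-ins for Prowler's legacy and universal requirement
# models; only the shapes matter.
from dataclasses import dataclass, field


@dataclass
class LegacyRequirement:        # legacy shape: CapitalCase, flat list
    Id: str
    Description: str
    Checks: list = field(default_factory=list)


@dataclass
class UniversalRequirement:     # universal shape: snake_case, per-provider dict
    id: str
    description: str
    checks: dict = field(default_factory=dict)


def normalize(requirement):
    """Return (id, description, flat check list) for either shape."""
    if isinstance(requirement, UniversalRequirement):
        flat = [c for checks in requirement.checks.values() for c in checks]
        return requirement.id, requirement.description, flat
    return requirement.Id, requirement.Description, list(requirement.Checks)
```

Normalizing once per requirement keeps the printing code below the branch identical for both shapes, which is exactly what the diff does with `req_id`, `req_description`, and `req_checks`.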
```diff
@@ -22,6 +22,7 @@ def load_checks_to_execute(
     categories: set = None,
     resource_groups: set = None,
     list_checks: bool = False,
+    universal_frameworks: dict = None,
 ) -> set:
     """Generate the list of checks to execute based on the cloud provider and the input arguments given"""
     try:
```
```diff
@@ -155,12 +156,21 @@ def load_checks_to_execute(
         if not bulk_compliance_frameworks:
             bulk_compliance_frameworks = Compliance.get_bulk(provider=provider)
         for compliance_framework in compliance_frameworks:
-            checks_to_execute.update(
-                CheckMetadata.list(
-                    bulk_compliance_frameworks=bulk_compliance_frameworks,
-                    compliance_framework=compliance_framework,
-                )
-            )
+            # Try universal frameworks first (snake_case dict-keyed checks)
+            if (
+                universal_frameworks
+                and compliance_framework in universal_frameworks
+            ):
+                fw = universal_frameworks[compliance_framework]
+                for req in fw.requirements:
+                    checks_to_execute.update(req.checks.get(provider.lower(), []))
+            elif compliance_framework in bulk_compliance_frameworks:
+                checks_to_execute.update(
+                    CheckMetadata.list(
+                        bulk_compliance_frameworks=bulk_compliance_frameworks,
+                        compliance_framework=compliance_framework,
+                    )
+                )
 
     # Handle if there are categories passed using --categories
     elif categories:
```
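The universal-first branch above selects checks by looking up the lowercased provider name in each requirement's per-provider dict, falling back to an empty list when the provider is absent. A standalone sketch with illustrative framework data (`SimpleNamespace` stands in for the real requirement objects):

```python
# Sketch of the universal-first check selection above. The framework
# name, check names, and SimpleNamespace shapes are illustrative.
from types import SimpleNamespace

universal_frameworks = {
    "csa_ccm_4.0": SimpleNamespace(
        requirements=[
            SimpleNamespace(checks={"aws": ["iam_check"], "gcp": ["gcs_check"]}),
            SimpleNamespace(checks={"aws": ["s3_check"]}),
        ]
    )
}


def checks_for(framework_name, provider):
    """Collect this provider's checks from a universal framework."""
    checks = set()
    fw = universal_frameworks.get(framework_name)
    if fw is not None:
        for req in fw.requirements:
            # Missing providers fall back to an empty list, as in the diff.
            checks.update(req.checks.get(provider.lower(), []))
    return checks
```

The `.lower()` call mirrors the diff: provider names arrive in mixed case, while the framework JSON keys its check lists by lowercase provider.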
```diff
@@ -102,6 +102,48 @@ class CIS_Requirement_Attribute(BaseModel):
     References: str
 
 
+class EssentialEight_Requirement_Attribute_MaturityLevel(str, Enum):
+    """ASD Essential Eight Maturity Level"""
+
+    ML1 = "ML1"
+    ML2 = "ML2"
+    ML3 = "ML3"
+
+
+class EssentialEight_Requirement_Attribute_AssessmentStatus(str, Enum):
+    """Essential Eight Requirement Attribute Assessment Status"""
+
+    Manual = "Manual"
+    Automated = "Automated"
+
+
+class EssentialEight_Requirement_Attribute_CloudApplicability(str, Enum):
+    """How well the ASD control maps to AWS cloud infrastructure."""
+
+    Full = "full"
+    Partial = "partial"
+    Limited = "limited"
+    NonApplicable = "non-applicable"
+
+
+# Essential Eight Requirement Attribute
+class EssentialEight_Requirement_Attribute(BaseModel):
+    """ASD Essential Eight Requirement Attribute"""
+
+    Section: str
+    MaturityLevel: EssentialEight_Requirement_Attribute_MaturityLevel
+    AssessmentStatus: EssentialEight_Requirement_Attribute_AssessmentStatus
+    CloudApplicability: EssentialEight_Requirement_Attribute_CloudApplicability
+    MitigatedThreats: list[str]
+    Description: str
+    RationaleStatement: str
+    ImpactStatement: str
+    RemediationProcedure: str
+    AuditProcedure: str
+    AdditionalInformation: str
+    References: str
+
+
 # Well Architected Requirement Attribute
 class AWS_Well_Architected_Requirement_Attribute(BaseModel):
     """AWS Well Architected Requirement Attribute"""
@@ -250,6 +292,7 @@ class Compliance_Requirement(BaseModel):
     Name: Optional[str] = None
     Attributes: list[
         Union[
+            EssentialEight_Requirement_Attribute,
             CIS_Requirement_Attribute,
             ENS_Requirement_Attribute,
             ISO27001_2013_Requirement_Attribute,
```
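The new attribute classes above rely on `str`-backed enums so that values loaded from JSON (`"ML1"`, `"partial"`) can be validated by value while still comparing equal to plain strings. A simplified sketch, using a dataclass and a subset of fields instead of Prowler's pydantic `BaseModel`:

```python
# Simplified sketch of the Essential Eight enum pattern above: str-backed
# enums constrain attribute values and still compare equal to raw strings.
# The dataclass and field subset are illustrative, not Prowler's model.
from dataclasses import dataclass
from enum import Enum


class MaturityLevel(str, Enum):
    ML1 = "ML1"
    ML2 = "ML2"
    ML3 = "ML3"


class CloudApplicability(str, Enum):
    Full = "full"
    Partial = "partial"
    Limited = "limited"
    NonApplicable = "non-applicable"


@dataclass
class EssentialEightAttribute:
    Section: str
    MaturityLevel: MaturityLevel
    CloudApplicability: CloudApplicability


attr = EssentialEightAttribute(
    Section="Patch Applications",
    MaturityLevel=MaturityLevel("ML1"),              # lookup by JSON value
    CloudApplicability=CloudApplicability("partial"),
)
```

An out-of-range value such as `MaturityLevel("ML4")` raises `ValueError`, which is the validation the enum buys over a bare `str` field.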
```diff
@@ -1,12 +1,17 @@
 import sys
 
 from prowler.lib.check.models import Check_Report
 from prowler.lib.logger import logger
 from prowler.lib.outputs.compliance.c5.c5 import get_c5_table
 from prowler.lib.outputs.compliance.ccc.ccc import get_ccc_table
 from prowler.lib.outputs.compliance.cis.cis import get_cis_table
+from prowler.lib.outputs.compliance.compliance_check import (  # noqa: F401 - re-export for backward compatibility
+    get_check_compliance,
+)
 from prowler.lib.outputs.compliance.csa.csa import get_csa_table
 from prowler.lib.outputs.compliance.ens.ens import get_ens_table
+from prowler.lib.outputs.compliance.essential_eight.essential_eight import (
+    get_essential_eight_table,
+)
 from prowler.lib.outputs.compliance.generic.generic_table import (
     get_generic_compliance_table,
 )
```
```diff
@@ -17,6 +22,94 @@ from prowler.lib.outputs.compliance.mitre_attack.mitre_attack import (
 from prowler.lib.outputs.compliance.prowler_threatscore.prowler_threatscore import (
     get_prowler_threatscore_table,
 )
+from prowler.lib.outputs.compliance.universal.universal_table import get_universal_table
+
+
+def process_universal_compliance_frameworks(
+    input_compliance_frameworks: set,
+    universal_frameworks: dict,
+    finding_outputs: list,
+    output_directory: str,
+    output_filename: str,
+    provider: str,
+    generated_outputs: dict,
+) -> set:
+    """Process universal compliance frameworks, generating CSV and OCSF outputs.
+
+    For each framework in *input_compliance_frameworks* that exists in
+    *universal_frameworks* and has an outputs.table_config, this function
+    creates both a CSV (UniversalComplianceOutput) and an OCSF JSON
+    (OCSFComplianceOutput) file. OCSF is always generated regardless of
+    the user's ``--output-formats`` flag.
+
+    The function is idempotent: it tracks already-created writers via
+    ``generated_outputs["compliance"]`` keyed by ``file_path``. If invoked
+    again for the same framework (e.g. once per streaming batch), it
+    reuses the existing writer instead of recreating it. This guarantees
+    one output writer per framework for the whole execution and keeps
+    the OCSF JSON array valid across multiple calls.
+
+    Returns the set of framework names that were processed so the caller
+    can remove them before entering the legacy per-provider output loop.
+    """
+    from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
+        OCSFComplianceOutput,
+    )
+    from prowler.lib.outputs.compliance.universal.universal_output import (
+        UniversalComplianceOutput,
+    )
+
+    existing_writers = {
+        getattr(out, "file_path", None): out
+        for out in generated_outputs.get("compliance", [])
+        if isinstance(out, (UniversalComplianceOutput, OCSFComplianceOutput))
+    }
+
+    processed = set()
+    for compliance_name in input_compliance_frameworks:
+        if not (
+            compliance_name in universal_frameworks
+            and universal_frameworks[compliance_name].outputs
+            and universal_frameworks[compliance_name].outputs.table_config
+        ):
+            continue
+
+        fw = universal_frameworks[compliance_name]
+
+        # CSV output
+        csv_path = (
+            f"{output_directory}/compliance/" f"{output_filename}_{compliance_name}.csv"
+        )
+        if csv_path not in existing_writers:
+            output = UniversalComplianceOutput(
+                findings=finding_outputs,
+                framework=fw,
+                file_path=csv_path,
+                provider=provider,
+            )
+            generated_outputs["compliance"].append(output)
+            existing_writers[csv_path] = output
+            output.batch_write_data_to_file()
+
+        # OCSF output (always generated for universal frameworks)
+        ocsf_path = (
+            f"{output_directory}/compliance/"
+            f"{output_filename}_{compliance_name}.ocsf.json"
+        )
+        if ocsf_path not in existing_writers:
+            ocsf_output = OCSFComplianceOutput(
+                findings=finding_outputs,
+                framework=fw,
+                file_path=ocsf_path,
+                provider=provider,
+            )
+            generated_outputs["compliance"].append(ocsf_output)
+            existing_writers[ocsf_path] = ocsf_output
+            ocsf_output.batch_write_data_to_file()
+
+        processed.add(compliance_name)
+
+    return processed
 
 
 def display_compliance_table(
```
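The idempotency contract described in the docstring above can be sketched independently of Prowler's writer classes: build a path-keyed index of already-created writers, create a writer only when its path is new, and return the processed names. `FakeWriter` and the path format are illustrative:

```python
# Illustrative sketch of the writer-caching contract from the docstring
# above: one writer per output path for the whole run, reused on later
# calls instead of being recreated. FakeWriter stands in for Prowler's
# real output classes.
class FakeWriter:
    def __init__(self, file_path):
        self.file_path = file_path
        self.batches_written = 0


def process(frameworks, generated_outputs):
    existing = {w.file_path: w for w in generated_outputs}
    processed = set()
    for name in frameworks:
        path = f"output/compliance/scan_{name}.csv"
        if path not in existing:
            writer = FakeWriter(path)
            generated_outputs.append(writer)
            existing[path] = writer
            writer.batches_written += 1  # stands in for batch_write_data_to_file()
        processed.add(name)
    return processed


outputs = []
process({"csa_ccm_4.0"}, outputs)
process({"csa_ccm_4.0"}, outputs)  # second call reuses the cached writer
```

Because the index is rebuilt from `generated_outputs` on every call, a second invocation for the same framework finds the existing writer and creates nothing new, which is what keeps one file handle (and one valid OCSF JSON array) per framework.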
```diff
@@ -26,6 +119,9 @@ def display_compliance_table(
     output_filename: str,
     output_directory: str,
     compliance_overview: bool,
+    universal_frameworks: dict = None,
+    provider: str = None,
+    output_formats: list = None,
 ) -> None:
     """
     display_compliance_table generates the compliance table for the given compliance framework.
@@ -37,6 +133,9 @@ def display_compliance_table(
         output_filename (str): The output filename
         output_directory (str): The output directory
         compliance_overview (bool): The compliance
+        universal_frameworks (dict): Optional universal ComplianceFramework objects
+        provider (str): The current provider (e.g. "aws") for multi-provider filtering
+        output_formats (list): The output formats to generate
 
     Returns:
         None
```
```diff
@@ -45,16 +144,24 @@ def display_compliance_table(
     findings = [f for f in findings if f.check_metadata.CheckID in bulk_checks_metadata]
 
     try:
-        if "ens_" in compliance_framework:
-            get_ens_table(
-                findings,
-                bulk_checks_metadata,
-                compliance_framework,
-                output_filename,
-                output_directory,
-                compliance_overview,
-            )
-        elif "cis_" in compliance_framework:
+        # Universal path: if the framework has TableConfig, use the universal renderer
+        if universal_frameworks and compliance_framework in universal_frameworks:
+            fw = universal_frameworks[compliance_framework]
+            if fw.outputs and fw.outputs.table_config:
+                get_universal_table(
+                    findings,
+                    bulk_checks_metadata,
+                    compliance_framework,
+                    output_filename,
+                    output_directory,
+                    compliance_overview,
+                    framework=fw,
+                    provider=provider,
+                    output_formats=output_formats,
+                )
+                return
+
+        if compliance_framework.startswith("cis_"):
             get_cis_table(
                 findings,
                 bulk_checks_metadata,
```
```diff
@@ -63,7 +170,16 @@ def display_compliance_table(
             output_directory,
             compliance_overview,
         )
-        elif "mitre_attack" in compliance_framework:
+        elif compliance_framework.startswith("ens_"):
+            get_ens_table(
+                findings,
+                bulk_checks_metadata,
+                compliance_framework,
+                output_filename,
+                output_directory,
+                compliance_overview,
+            )
+        elif compliance_framework.startswith("mitre_attack"):
             get_mitre_attack_table(
                 findings,
                 bulk_checks_metadata,
@@ -72,7 +188,7 @@ def display_compliance_table(
             output_directory,
             compliance_overview,
         )
-        elif "kisa_isms_" in compliance_framework:
+        elif compliance_framework.startswith("kisa"):
             get_kisa_ismsp_table(
                 findings,
                 bulk_checks_metadata,
@@ -81,7 +197,7 @@ def display_compliance_table(
             output_directory,
             compliance_overview,
         )
-        elif "threatscore_" in compliance_framework:
+        elif compliance_framework.startswith("prowler_threatscore_"):
             get_prowler_threatscore_table(
                 findings,
                 bulk_checks_metadata,
@@ -90,7 +206,7 @@ def display_compliance_table(
             output_directory,
             compliance_overview,
         )
-        elif "csa_ccm_" in compliance_framework:
+        elif compliance_framework.startswith("csa_ccm_"):
             get_csa_table(
                 findings,
                 bulk_checks_metadata,
@@ -99,7 +215,7 @@ def display_compliance_table(
             output_directory,
             compliance_overview,
         )
-        elif "c5_" in compliance_framework:
+        elif compliance_framework.startswith("c5_"):
             get_c5_table(
                 findings,
                 bulk_checks_metadata,
```
```diff
@@ -117,6 +233,15 @@ def display_compliance_table(
             output_directory,
             compliance_overview,
         )
+        elif "essential_eight" in compliance_framework:
+            get_essential_eight_table(
+                findings,
+                bulk_checks_metadata,
+                compliance_framework,
+                output_filename,
+                output_directory,
+                compliance_overview,
+            )
         else:
             get_generic_compliance_table(
                 findings,
```
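The motivation for moving this dispatch from substring `in` checks to `startswith` (noted in the changelog as the `cisa` vs `cis` collision) is easy to demonstrate:

```python
# Why the dispatch moved from substring `in` checks to `startswith`:
# a substring test matches framework names it was never meant to catch.
framework = "cisa_scuba_aws"

# Substring check: "cis" occurs inside "cisa", so this wrongly matches.
print("cis" in framework)                 # True - false positive

# Prefix check against the CIS naming convention ("cis_...") does not.
print(framework.startswith("cis_"))       # False - correctly rejected

# A genuine CIS framework name still matches the prefix check.
print("cis_4.0_aws".startswith("cis_"))   # True
```

The prefix test encodes the naming convention (the framework family followed by `_`) instead of matching any occurrence of the letters, so similarly named families can no longer shadow each other.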
```diff
@@ -131,49 +256,3 @@ def display_compliance_table(
             f"{error.__class__.__name__}:{error.__traceback__.tb_lineno} -- {error}"
         )
         sys.exit(1)
-
-
-# TODO: this should be in the Check class
-def get_check_compliance(
-    finding: Check_Report, provider_type: str, bulk_checks_metadata: dict
-) -> dict:
-    """get_check_compliance returns a map with the compliance framework as key and the requirements where the finding's check is present.
-
-    Example:
-
-        {
-            "CIS-1.4": ["2.1.3"],
-            "CIS-1.5": ["2.1.3"],
-        }
-
-    Args:
-        finding (Any): The Check_Report finding
-        provider_type (str): The provider type
-        bulk_checks_metadata (dict): The bulk checks metadata
-
-    Returns:
-        dict: The compliance framework as key and the requirements where the finding's check is present.
-    """
-    try:
-        check_compliance = {}
-        # We have to retrieve all the check's compliance requirements
-        if finding.check_metadata.CheckID in bulk_checks_metadata:
-            for compliance in bulk_checks_metadata[
-                finding.check_metadata.CheckID
-            ].Compliance:
-                compliance_fw = compliance.Framework
-                if compliance.Version:
-                    compliance_fw = f"{compliance_fw}-{compliance.Version}"
-                # compliance.Provider == "Azure" or "Kubernetes"
-                # provider_type == "azure" or "kubernetes"
-                if compliance.Provider.upper() == provider_type.upper():
-                    if compliance_fw not in check_compliance:
-                        check_compliance[compliance_fw] = []
-                    for requirement in compliance.Requirements:
-                        check_compliance[compliance_fw].append(requirement.Id)
-        return check_compliance
-    except Exception as error:
-        logger.error(
-            f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
-        )
-        return {}
```
@@ -0,0 +1,48 @@
|
||||
from prowler.lib.check.models import Check_Report
|
||||
from prowler.lib.logger import logger
|
||||
|
||||
|
||||
# TODO: this should be in the Check class
|
||||
def get_check_compliance(
|
||||
finding: Check_Report, provider_type: str, bulk_checks_metadata: dict
|
||||
) -> dict:
|
||||
"""get_check_compliance returns a map with the compliance framework as key and the requirements where the finding's check is present.
|
||||
|
||||
Example:
|
||||
|
||||
{
|
||||
"CIS-1.4": ["2.1.3"],
|
||||
"CIS-1.5": ["2.1.3"],
|
||||
}
|
||||
|
||||
Args:
|
||||
finding (Any): The Check_Report finding
|
||||
provider_type (str): The provider type
|
||||
bulk_checks_metadata (dict): The bulk checks metadata
|
||||
|
||||
Returns:
|
||||
dict: The compliance framework as key and the requirements where the finding's check is present.
|
||||
"""
|
||||
try:
|
||||
check_compliance = {}
|
||||
# We have to retrieve all the check's compliance requirements
|
||||
if finding.check_metadata.CheckID in bulk_checks_metadata:
|
||||
for compliance in bulk_checks_metadata[
|
||||
finding.check_metadata.CheckID
|
||||
].Compliance:
|
||||
compliance_fw = compliance.Framework
|
||||
if compliance.Version:
|
||||
compliance_fw = f"{compliance_fw}-{compliance.Version}"
|
||||
# compliance.Provider == "Azure" or "Kubernetes"
|
||||
# provider_type == "azure" or "kubernetes"
|
||||
if compliance.Provider.upper() == provider_type.upper():
|
||||
if compliance_fw not in check_compliance:
|
||||
check_compliance[compliance_fw] = []
|
||||
for requirement in compliance.Requirements:
|
||||
check_compliance[compliance_fw].append(requirement.Id)
|
||||
return check_compliance
|
||||
except Exception as error:
|
||||
logger.error(
|
||||
f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}] -- {error}"
|
||||
)
|
||||
return {}
|
||||
@@ -0,0 +1,98 @@
from colorama import Fore, Style
from tabulate import tabulate

from prowler.config.config import orange_color


def get_essential_eight_table(
    findings: list,
    bulk_checks_metadata: dict,
    compliance_framework: str,
    output_filename: str,
    output_directory: str,
    compliance_overview: bool,
):
    sections = {}
    essential_eight_compliance_table = {
        "Provider": [],
        "Section": [],
        "Status": [],
        "Muted": [],
    }
    pass_count = []
    fail_count = []
    muted_count = []
    for index, finding in enumerate(findings):
        check = bulk_checks_metadata[finding.check_metadata.CheckID]
        check_compliances = check.Compliance
        for compliance in check_compliances:
            if compliance.Framework == "Essential-Eight":
                for requirement in compliance.Requirements:
                    for attribute in requirement.Attributes:
                        section = attribute.Section
                        if section not in sections:
                            sections[section] = {
                                "FAIL": 0,
                                "PASS": 0,
                                "Muted": 0,
                            }
                        if finding.muted:
                            if index not in muted_count:
                                muted_count.append(index)
                            sections[section]["Muted"] += 1
                        else:
                            if finding.status == "FAIL" and index not in fail_count:
                                fail_count.append(index)
                                sections[section]["FAIL"] += 1
                            elif finding.status == "PASS" and index not in pass_count:
                                pass_count.append(index)
                                sections[section]["PASS"] += 1

    sections = dict(sorted(sections.items()))
    for section in sections:
        essential_eight_compliance_table["Provider"].append(compliance.Provider)
        essential_eight_compliance_table["Section"].append(section)
        if sections[section]["FAIL"] > 0:
            essential_eight_compliance_table["Status"].append(
                f"{Fore.RED}FAIL({sections[section]['FAIL']}){Style.RESET_ALL}"
            )
        elif sections[section]["PASS"] > 0:
            essential_eight_compliance_table["Status"].append(
                f"{Fore.GREEN}PASS({sections[section]['PASS']}){Style.RESET_ALL}"
            )
        else:
            essential_eight_compliance_table["Status"].append("-")
        essential_eight_compliance_table["Muted"].append(
            f"{orange_color}{sections[section]['Muted']}{Style.RESET_ALL}"
        )
    if len(fail_count) + len(pass_count) + len(muted_count) > 1:
        print(
            f"\nCompliance Status of {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Framework:"
        )
        total_findings_count = len(fail_count) + len(pass_count) + len(muted_count)
        overview_table = [
            [
                f"{Fore.RED}{round(len(fail_count) / total_findings_count * 100, 2)}% ({len(fail_count)}) FAIL{Style.RESET_ALL}",
                f"{Fore.GREEN}{round(len(pass_count) / total_findings_count * 100, 2)}% ({len(pass_count)}) PASS{Style.RESET_ALL}",
                f"{orange_color}{round(len(muted_count) / total_findings_count * 100, 2)}% ({len(muted_count)}) MUTED{Style.RESET_ALL}",
            ]
        ]
        print(tabulate(overview_table, tablefmt="rounded_grid"))
        if not compliance_overview:
            print(
                f"\nFramework {Fore.YELLOW}{compliance_framework.upper()}{Style.RESET_ALL} Results:"
            )
            print(
                tabulate(
                    essential_eight_compliance_table,
                    headers="keys",
                    tablefmt="rounded_grid",
                )
            )
            print(
                f"{Style.BRIGHT}* Only sections containing results appear.{Style.RESET_ALL}"
            )
            print(f"\nDetailed results of {compliance_framework.upper()} are in:")
            print(
                f" - CSV: {output_directory}/compliance/{output_filename}_{compliance_framework}.csv\n"
            )
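The overview table above reports percentages over deduplicated finding indexes (each finding counted once across all sections). A minimal sketch of that percentage arithmetic, with hypothetical index lists:

```python
# Hypothetical index lists mirroring fail_count/pass_count/muted_count above:
# findings 0 and 3 failed, finding 1 passed, finding 2 was muted.
fail_count, pass_count, muted_count = [0, 3], [1], [2]

total = len(fail_count) + len(pass_count) + len(muted_count)
fail_pct = round(len(fail_count) / total * 100, 2)
pass_pct = round(len(pass_count) / total * 100, 2)
muted_pct = round(len(muted_count) / total * 100, 2)

print(f"{fail_pct}% FAIL, {pass_pct}% PASS, {muted_pct}% MUTED")
# → 50.0% FAIL, 25.0% PASS, 25.0% MUTED
```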
@@ -0,0 +1,111 @@
from prowler.config.config import timestamp
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.outputs.compliance.compliance_output import ComplianceOutput
from prowler.lib.outputs.compliance.essential_eight.models import (
    EssentialEightAWSModel,
)
from prowler.lib.outputs.finding import Finding


class EssentialEightAWS(ComplianceOutput):
    """
    This class represents the AWS ASD Essential Eight compliance output.

    Attributes:
        - _data (list): A list to store transformed data from findings.
        - _file_descriptor (TextIOWrapper): A file descriptor to write data to a file.

    Methods:
        - transform: Transforms findings into AWS Essential Eight compliance format.
    """

    def transform(
        self,
        findings: list[Finding],
        compliance: Compliance,
        compliance_name: str,
    ) -> None:
        """
        Transforms a list of findings into AWS Essential Eight compliance format.

        Parameters:
            - findings (list): A list of findings.
            - compliance (Compliance): A compliance model.
            - compliance_name (str): The name of the compliance model.

        Returns:
            - None
        """
        for finding in findings:
            finding_requirements = finding.compliance.get(compliance_name, [])
            for requirement in compliance.Requirements:
                if requirement.Id in finding_requirements:
                    for attribute in requirement.Attributes:
                        compliance_row = EssentialEightAWSModel(
                            Provider=finding.provider,
                            Description=compliance.Description,
                            AccountId=finding.account_uid,
                            Region=finding.region,
                            AssessmentDate=str(timestamp),
                            Requirements_Id=requirement.Id,
                            Requirements_Description=requirement.Description,
                            Requirements_Attributes_Section=attribute.Section,
                            Requirements_Attributes_MaturityLevel=attribute.MaturityLevel,
                            Requirements_Attributes_AssessmentStatus=attribute.AssessmentStatus,
                            Requirements_Attributes_CloudApplicability=attribute.CloudApplicability,
                            Requirements_Attributes_MitigatedThreats=", ".join(
                                attribute.MitigatedThreats
                            ),
                            Requirements_Attributes_Description=attribute.Description,
                            Requirements_Attributes_RationaleStatement=attribute.RationaleStatement,
                            Requirements_Attributes_ImpactStatement=attribute.ImpactStatement,
                            Requirements_Attributes_RemediationProcedure=attribute.RemediationProcedure,
                            Requirements_Attributes_AuditProcedure=attribute.AuditProcedure,
                            Requirements_Attributes_AdditionalInformation=attribute.AdditionalInformation,
                            Requirements_Attributes_References=attribute.References,
                            Status=finding.status,
                            StatusExtended=finding.status_extended,
                            ResourceId=finding.resource_uid,
                            ResourceName=finding.resource_name,
                            CheckId=finding.check_id,
                            Muted=finding.muted,
                            Framework=compliance.Framework,
                            Name=compliance.Name,
                        )
                        self._data.append(compliance_row)
        # Add manual requirements to the compliance output
        for requirement in compliance.Requirements:
            if not requirement.Checks:
                for attribute in requirement.Attributes:
                    compliance_row = EssentialEightAWSModel(
                        Provider=compliance.Provider.lower(),
                        Description=compliance.Description,
                        AccountId="",
                        Region="",
                        AssessmentDate=str(timestamp),
                        Requirements_Id=requirement.Id,
                        Requirements_Description=requirement.Description,
                        Requirements_Attributes_Section=attribute.Section,
                        Requirements_Attributes_MaturityLevel=attribute.MaturityLevel,
                        Requirements_Attributes_AssessmentStatus=attribute.AssessmentStatus,
                        Requirements_Attributes_CloudApplicability=attribute.CloudApplicability,
                        Requirements_Attributes_MitigatedThreats=", ".join(
                            attribute.MitigatedThreats
                        ),
                        Requirements_Attributes_Description=attribute.Description,
                        Requirements_Attributes_RationaleStatement=attribute.RationaleStatement,
                        Requirements_Attributes_ImpactStatement=attribute.ImpactStatement,
                        Requirements_Attributes_RemediationProcedure=attribute.RemediationProcedure,
                        Requirements_Attributes_AuditProcedure=attribute.AuditProcedure,
                        Requirements_Attributes_AdditionalInformation=attribute.AdditionalInformation,
                        Requirements_Attributes_References=attribute.References,
                        Status="MANUAL",
                        StatusExtended="Manual check",
                        ResourceId="manual_check",
                        ResourceName="Manual check",
                        CheckId="manual",
                        Muted=False,
                        Framework=compliance.Framework,
                        Name=compliance.Name,
                    )
                    self._data.append(compliance_row)
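The `transform` above keys on `finding.compliance.get(compliance_name, [])` and emits a row only for requirements whose `Id` appears in that list. A minimal sketch of the matching step with plain dicts and lists — the requirement IDs are hypothetical, not taken from the real Essential Eight mapping:

```python
# Hypothetical: a finding's per-framework requirement IDs, as produced by
# get_check_compliance, matched against a framework's requirement list.
finding_compliance = {"Essential-Eight-2023": ["ML1-AC-01", "ML2-AC-03"]}
framework_requirements = ["ML1-AC-01", "ML1-PM-02", "ML2-AC-03"]

finding_requirements = finding_compliance.get("Essential-Eight-2023", [])
matched = [r for r in framework_requirements if r in finding_requirements]
print(matched)  # → ['ML1-AC-01', 'ML2-AC-03']
```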
@@ -0,0 +1,35 @@
from pydantic.v1 import BaseModel


class EssentialEightAWSModel(BaseModel):
    """
    EssentialEightAWSModel generates a finding's output in AWS ASD Essential Eight Compliance format.
    """

    Provider: str
    Description: str
    AccountId: str
    Region: str
    AssessmentDate: str
    Requirements_Id: str
    Requirements_Description: str
    Requirements_Attributes_Section: str
    Requirements_Attributes_MaturityLevel: str
    Requirements_Attributes_AssessmentStatus: str
    Requirements_Attributes_CloudApplicability: str
    Requirements_Attributes_MitigatedThreats: str
    Requirements_Attributes_Description: str
    Requirements_Attributes_RationaleStatement: str
    Requirements_Attributes_ImpactStatement: str
    Requirements_Attributes_RemediationProcedure: str
    Requirements_Attributes_AuditProcedure: str
    Requirements_Attributes_AdditionalInformation: str
    Requirements_Attributes_References: str
    Status: str
    StatusExtended: str
    ResourceId: str
    ResourceName: str
    CheckId: str
    Muted: bool
    Framework: str
    Name: str
@@ -1,6 +1,7 @@
 import json
 import os
 from datetime import datetime
-from typing import List
+from typing import TYPE_CHECKING, List

 from py_ocsf_models.events.base_event import SeverityID
+from py_ocsf_models.events.base_event import StatusID as EventStatusID
@@ -20,11 +21,12 @@ from py_ocsf_models.objects.resource_details import ResourceDetails
 from prowler.config.config import prowler_version
 from prowler.lib.check.compliance_models import ComplianceFramework
 from prowler.lib.logger import logger
-from prowler.lib.outputs.finding import Finding
-from prowler.lib.outputs.ocsf.ocsf import OCSF
 from prowler.lib.outputs.utils import unroll_dict_to_list
 from prowler.lib.utils.utils import open_file

+if TYPE_CHECKING:
+    from prowler.lib.outputs.finding import Finding
+
 PROWLER_TO_COMPLIANCE_STATUS = {
     "PASS": ComplianceStatusID.Pass,
     "FAIL": ComplianceStatusID.Fail,
@@ -32,6 +34,40 @@ PROWLER_TO_COMPLIANCE_STATUS = {
 }


+def _sanitize_resource_data(resource_details, resource_metadata) -> dict:
+    """Ensure resource data is JSON-serializable.
+
+    Service resource_metadata may carry non-serializable objects (e.g. raw
+    Pydantic models or service classes such as ``Trail`` / ``LifecyclePolicy``).
+    Convert them to plain dicts and roundtrip through JSON so the resulting
+    ComplianceFinding can be serialized without errors.
+    """
+
+    def _make_serializable(obj):
+        if hasattr(obj, "model_dump") and callable(obj.model_dump):
+            return _make_serializable(obj.model_dump())
+        if hasattr(obj, "dict") and callable(obj.dict):
+            return _make_serializable(obj.dict())
+        if isinstance(obj, dict):
+            return {str(k): _make_serializable(v) for k, v in obj.items()}
+        if isinstance(obj, (list, tuple)):
+            return [_make_serializable(v) for v in obj]
+        return obj
+
+    try:
+        converted = _make_serializable(resource_metadata)
+        sanitized_metadata = json.loads(json.dumps(converted, default=str))
+    except (TypeError, ValueError, RecursionError) as error:
+        logger.warning(
+            f"Failed to serialize resource metadata, defaulting to empty: {error}"
+        )
+        sanitized_metadata = {}
+    return {
+        "details": resource_details,
+        "metadata": sanitized_metadata,
+    }
+
+
 def _to_snake_case(name: str) -> str:
     """Convert a PascalCase or camelCase string to snake_case."""
     import re
@@ -108,7 +144,7 @@ class OCSFComplianceOutput:

     def _transform(
         self,
-        findings: List[Finding],
+        findings: List["Finding"],
         framework: ComplianceFramework,
         compliance_name: str,
     ) -> None:
@@ -177,7 +213,7 @@ class OCSFComplianceOutput:

     def _build_compliance_finding(
         self,
-        finding: Finding,
+        finding: "Finding",
         framework: ComplianceFramework,
         requirement,
         compliance_name: str,
@@ -195,7 +231,9 @@ class OCSFComplianceOutput:
             finding.metadata.Severity.capitalize(),
             SeverityID.Unknown,
         )
-        event_status = OCSF.get_finding_status_id(finding.muted)
+        event_status = (
+            EventStatusID.Suppressed if finding.muted else EventStatusID.New
+        )

         time_value = (
             int(finding.timestamp.timestamp())
@@ -268,10 +306,10 @@ class OCSFComplianceOutput:
                         if finding.provider == "kubernetes"
                         else None
                     ),
-                    data={
-                        "details": finding.resource_details,
-                        "metadata": finding.resource_metadata,
-                    },
+                    data=_sanitize_resource_data(
+                        finding.resource_details,
+                        finding.resource_metadata,
+                    ),
                 )
             ],
             severity_id=finding_severity.value,
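The `_make_serializable` helper in `_sanitize_resource_data` recurses through dicts, lists, and objects exposing a pydantic-style `.dict()`/`.model_dump()` hook before the JSON roundtrip. A self-contained, stdlib-only sketch of that recursion — `FakeTrail` is a hypothetical stand-in for a service model, not a Prowler class:

```python
import json
from dataclasses import dataclass


# Hypothetical object standing in for a service model (e.g. a Trail) that
# json.dumps cannot serialize directly.
@dataclass
class FakeTrail:
    name: str

    def dict(self):  # mimics the pydantic v1 .dict() hook the sanitizer probes for
        return {"name": self.name}


def make_serializable(obj):
    # Objects with a .dict() hook become plain dicts, then recurse.
    if hasattr(obj, "dict") and callable(obj.dict):
        return make_serializable(obj.dict())
    if isinstance(obj, dict):
        return {str(k): make_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [make_serializable(v) for v in obj]
    return obj


metadata = {"trail": FakeTrail("main"), "regions": ("us-east-1",)}
# Roundtrip through JSON, stringifying anything still non-serializable.
sanitized = json.loads(json.dumps(make_serializable(metadata), default=str))
print(sanitized)  # → {'trail': {'name': 'main'}, 'regions': ['us-east-1']}
```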
@@ -0,0 +1,294 @@
from csv import DictWriter
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from pydantic.v1 import create_model

from prowler.config.config import timestamp
from prowler.lib.check.compliance_models import ComplianceFramework
from prowler.lib.logger import logger
from prowler.lib.utils.utils import open_file

if TYPE_CHECKING:
    from prowler.lib.outputs.finding import Finding

PROVIDER_HEADER_MAP = {
    "aws": ("AccountId", "account_uid", "Region", "region"),
    "azure": ("SubscriptionId", "account_uid", "Location", "region"),
    "gcp": ("ProjectId", "account_uid", "Location", "region"),
    "kubernetes": ("Context", "account_name", "Namespace", "region"),
    "m365": ("TenantId", "account_uid", "Location", "region"),
    "github": ("Account_Name", "account_name", "Account_Id", "account_uid"),
    "oraclecloud": ("TenancyId", "account_uid", "Region", "region"),
    "alibabacloud": ("AccountId", "account_uid", "Region", "region"),
    "nhn": ("AccountId", "account_uid", "Region", "region"),
}
_DEFAULT_HEADERS = ("AccountId", "account_uid", "Region", "region")


class UniversalComplianceOutput:
    """Universal compliance CSV output driven by ComplianceFramework metadata.

    Dynamically builds a Pydantic row model from attributes_metadata so that
    CSV columns match the framework's declared attribute fields.
    """

    def __init__(
        self,
        findings: list,
        framework: ComplianceFramework,
        file_path: str = None,
        from_cli: bool = True,
        provider: str = None,
    ) -> None:
        self._data = []
        self._file_descriptor = None
        self.file_path = file_path
        self._from_cli = from_cli
        self._provider = provider
        self.close_file = False

        if file_path:
            path_obj = Path(file_path)
            self._file_extension = path_obj.suffix if path_obj.suffix else ""

        if findings:
            self._row_model = self._build_row_model(framework)
            compliance_name = (
                framework.framework + "-" + framework.version
                if framework.version
                else framework.framework
            )
            self._transform(findings, framework, compliance_name)
            if not self._file_descriptor and file_path:
                self._create_file_descriptor(file_path)

    @property
    def data(self):
        return self._data

    def _build_row_model(self, framework: ComplianceFramework):
        """Build a dynamic Pydantic model from attributes_metadata."""
        acct_header, acct_field, loc_header, loc_field = PROVIDER_HEADER_MAP.get(
            (self._provider or "").lower(), _DEFAULT_HEADERS
        )
        self._acct_header = acct_header
        self._acct_field = acct_field
        self._loc_header = loc_header
        self._loc_field = loc_field

        # Base fields present in every compliance CSV
        fields = {
            "Provider": (str, ...),
            "Description": (str, ...),
            acct_header: (str, ...),
            loc_header: (str, ...),
            "AssessmentDate": (str, ...),
            "Requirements_Id": (str, ...),
            "Requirements_Description": (str, ...),
        }

        # Dynamic attribute columns from metadata
        if framework.attributes_metadata:
            for attr_meta in framework.attributes_metadata:
                if not attr_meta.output_formats.csv:
                    continue
                field_name = f"Requirements_Attributes_{attr_meta.key}"
                # Map type strings to Python types
                type_map = {
                    "str": Optional[str],
                    "int": Optional[int],
                    "float": Optional[float],
                    "bool": Optional[bool],
                    "list_str": Optional[str],  # Serialized as joined string
                    "list_dict": Optional[str],  # Serialized as string
                }
                py_type = type_map.get(attr_meta.type, Optional[str])
                fields[field_name] = (py_type, None)

        # Check if any requirement has MITRE fields
        has_mitre = any(req.tactics for req in framework.requirements if req.tactics)
        if has_mitre:
            fields["Requirements_Tactics"] = (Optional[str], None)
            fields["Requirements_SubTechniques"] = (Optional[str], None)
            fields["Requirements_Platforms"] = (Optional[str], None)
            fields["Requirements_TechniqueURL"] = (Optional[str], None)

        # Trailing fields
        fields["Status"] = (str, ...)
        fields["StatusExtended"] = (str, ...)
        fields["ResourceId"] = (str, ...)
        fields["ResourceName"] = (str, ...)
        fields["CheckId"] = (str, ...)
        fields["Muted"] = (bool, ...)
        fields["Framework"] = (str, ...)
        fields["Name"] = (str, ...)

        return create_model("UniversalComplianceRow", **fields)

    def _serialize_attr_value(self, value):
        """Serialize attribute values for CSV."""
        if isinstance(value, list):
            if value and isinstance(value[0], dict):
                return str(value)
            return " | ".join(str(v) for v in value)
        return value

    def _build_row(self, finding, framework, requirement, is_manual=False):
        """Build a single row dict for a finding + requirement combination."""
        row = {
            "Provider": (
                finding.provider
                if not is_manual
                else (framework.provider or self._provider or "").lower()
            ),
            "Description": framework.description,
            self._acct_header: (
                getattr(finding, self._acct_field, "") if not is_manual else ""
            ),
            self._loc_header: (
                getattr(finding, self._loc_field, "") if not is_manual else ""
            ),
            "AssessmentDate": str(timestamp),
            "Requirements_Id": requirement.id,
            "Requirements_Description": requirement.description,
        }

        # Add dynamic attribute columns
        if framework.attributes_metadata:
            for attr_meta in framework.attributes_metadata:
                if not attr_meta.output_formats.csv:
                    continue
                field_name = f"Requirements_Attributes_{attr_meta.key}"
                raw_val = requirement.attributes.get(attr_meta.key)
                row[field_name] = (
                    self._serialize_attr_value(raw_val) if raw_val is not None else None
                )

        # MITRE fields
        if requirement.tactics:
            row["Requirements_Tactics"] = (
                " | ".join(requirement.tactics) if requirement.tactics else None
            )
            row["Requirements_SubTechniques"] = (
                " | ".join(requirement.sub_techniques)
                if requirement.sub_techniques
                else None
            )
            row["Requirements_Platforms"] = (
                " | ".join(requirement.platforms) if requirement.platforms else None
            )
            row["Requirements_TechniqueURL"] = requirement.technique_url

        row["Status"] = finding.status if not is_manual else "MANUAL"
        row["StatusExtended"] = (
            finding.status_extended if not is_manual else "Manual check"
        )
        row["ResourceId"] = finding.resource_uid if not is_manual else "manual_check"
        row["ResourceName"] = finding.resource_name if not is_manual else "Manual check"
        row["CheckId"] = finding.check_id if not is_manual else "manual"
        row["Muted"] = finding.muted if not is_manual else False
        row["Framework"] = framework.framework
        row["Name"] = framework.name

        return row

    def _transform(
        self,
        findings: list["Finding"],
        framework: ComplianceFramework,
        compliance_name: str,
    ) -> None:
        """Transform findings into universal compliance CSV rows."""
        # Build check -> requirements map (filtered by provider for dict checks)
        check_req_map = {}
        for req in framework.requirements:
            checks = req.checks
            if self._provider:
                all_checks = checks.get(self._provider.lower(), [])
            else:
                all_checks = []
                for check_list in checks.values():
                    all_checks.extend(check_list)
            for check_id in all_checks:
                if check_id not in check_req_map:
                    check_req_map[check_id] = []
                check_req_map[check_id].append(req)

        # Process findings using the provider-filtered check_req_map.
        # This ensures that for multi-provider dict checks, only the checks
        # belonging to the current provider produce output rows.
        for finding in findings:
            check_id = finding.check_id
            if check_id in check_req_map:
                for req in check_req_map[check_id]:
                    row = self._build_row(finding, framework, req)
                    try:
                        self._data.append(self._row_model(**row))
                    except Exception as e:
                        logger.debug(f"Skipping row for {req.id}: {e}")

        # Manual requirements (no checks or empty dict)
        for req in framework.requirements:
            checks = req.checks
            if self._provider:
                has_checks = bool(checks.get(self._provider.lower(), []))
            else:
                has_checks = any(checks.values())

            if not has_checks:
                # Use a dummy finding-like namespace for manual rows
                row = self._build_row(
                    _ManualFindingStub(), framework, req, is_manual=True
                )
                try:
                    self._data.append(self._row_model(**row))
                except Exception as e:
                    logger.debug(f"Skipping manual row for {req.id}: {e}")

    def _create_file_descriptor(self, file_path: str) -> None:
        try:
            self._file_descriptor = open_file(file_path, "a")
        except Exception as error:
            logger.error(
                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    def batch_write_data_to_file(self) -> None:
        """Write findings data to CSV."""
        try:
            if (
                getattr(self, "_file_descriptor", None)
                and not self._file_descriptor.closed
                and self._data
            ):
                csv_writer = DictWriter(
                    self._file_descriptor,
                    fieldnames=[field.upper() for field in self._data[0].dict().keys()],
                    delimiter=";",
                )
                if self._file_descriptor.tell() == 0:
                    csv_writer.writeheader()
                for row in self._data:
                    csv_writer.writerow({k.upper(): v for k, v in row.dict().items()})
                if self.close_file or self._from_cli:
                    self._file_descriptor.close()
        except Exception as error:
            logger.error(
                f"{error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )


class _ManualFindingStub:
    """Minimal stub to satisfy _build_row for manual requirements."""

    provider = ""
    account_uid = ""
    account_name = ""
    region = ""
    status = "MANUAL"
    status_extended = "Manual check"
    resource_uid = "manual_check"
    resource_name = "Manual check"
    check_id = "manual"
    muted = False
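`_build_row_model` resolves the provider-specific CSV headers through `PROVIDER_HEADER_MAP` with a shared fallback. A minimal sketch of that lookup (the two map entries are copied from the source; `resolve_headers` is an illustrative helper, not a real method):

```python
# Subset of the provider -> (account header, account field, location header,
# location field) map from the source, plus its shared default.
PROVIDER_HEADER_MAP = {
    "aws": ("AccountId", "account_uid", "Region", "region"),
    "kubernetes": ("Context", "account_name", "Namespace", "region"),
}
_DEFAULT_HEADERS = ("AccountId", "account_uid", "Region", "region")


def resolve_headers(provider):
    # `provider` may be None, so normalize before the lookup,
    # exactly as _build_row_model does with (self._provider or "").lower()
    return PROVIDER_HEADER_MAP.get((provider or "").lower(), _DEFAULT_HEADERS)


print(resolve_headers("Kubernetes"))  # → ('Context', 'account_name', 'Namespace', 'region')
print(resolve_headers(None))          # → ('AccountId', 'account_uid', 'Region', 'region')
```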
@@ -15,7 +15,7 @@ from prowler.lib.check.models import (
 )
 from prowler.lib.logger import logger
 from prowler.lib.outputs.common import Status, fill_common_finding_data
-from prowler.lib.outputs.compliance.compliance import get_check_compliance
+from prowler.lib.outputs.compliance.compliance_check import get_check_compliance
 from prowler.lib.outputs.utils import unroll_tags
 from prowler.lib.utils.utils import dict_to_lowercase, get_nested_attribute
 from prowler.providers.common.provider import Provider
@@ -25,8 +25,8 @@ from prowler.lib.utils.utils import open_file, parse_json_file, print_boxes
 from prowler.providers.aws.config import (
     AWS_REGION_US_EAST_1,
     AWS_STS_GLOBAL_ENDPOINT_REGION,
-    BOTO3_USER_AGENT_EXTRA,
     ROLE_SESSION_NAME,
+    get_default_session_config,
 )
 from prowler.providers.aws.exceptions.exceptions import (
     AWSAccessKeyIDInvalidError,
@@ -227,14 +227,15 @@ class AwsProvider(Provider):

         # TODO: Use AwsSetUpSession ?????
         # Configure the initial AWS Session using the local credentials: profile or environment variables
+        session_config = self.set_session_config(retries_max_attempts)
         aws_session = self.setup_session(
             mfa=mfa,
             profile=profile,
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
             aws_session_token=aws_session_token,
+            session_config=session_config,
         )
-        session_config = self.set_session_config(retries_max_attempts)
         # Current session and the original session points to the same session object until we get a new one, if needed
         self._session = AWSSession(
             current_session=aws_session,
@@ -630,6 +631,7 @@ class AwsProvider(Provider):
         aws_access_key_id: str = None,
         aws_secret_access_key: str = None,
         aws_session_token: Optional[str] = None,
+        session_config: Optional[Config] = None,
     ) -> Session:
         """
         setup_session sets up an AWS session using the provided credentials.
@@ -640,6 +642,9 @@ class AwsProvider(Provider):
         - aws_access_key_id: The AWS access key ID.
         - aws_secret_access_key: The AWS secret access key.
         - aws_session_token: The AWS session token, optional.
+        - session_config: Botocore Config applied as the session's default
+            client config so every client created from the session inherits
+            the Prowler user agent and retry settings.

         Returns:
         - Session: The AWS session.
@@ -650,6 +655,9 @@ class AwsProvider(Provider):
         try:
             logger.debug("Creating original session ...")

+            if session_config is None:
+                session_config = AwsProvider.set_session_config(None)
+
             session_arguments = {}
             if profile:
                 session_arguments["profile_name"] = profile
@@ -661,6 +669,7 @@ class AwsProvider(Provider):

             if mfa:
                 session = Session(**session_arguments)
+                session._session.set_default_client_config(session_config)
                 sts_client = session.client("sts")

                 # TODO: pass values from the input
@@ -673,7 +682,7 @@ class AwsProvider(Provider):
                 session_credentials = sts_client.get_session_token(
                     **get_session_token_arguments
                 )
-                return Session(
+                mfa_session = Session(
                     aws_access_key_id=session_credentials["Credentials"]["AccessKeyId"],
                     aws_secret_access_key=session_credentials["Credentials"][
                         "SecretAccessKey"
@@ -682,8 +691,12 @@ class AwsProvider(Provider):
                         "SessionToken"
                     ],
                 )
+                mfa_session._session.set_default_client_config(session_config)
+                return mfa_session
             else:
-                return Session(**session_arguments)
+                session = Session(**session_arguments)
+                session._session.set_default_client_config(session_config)
+                return session
         except Exception as error:
             logger.critical(
                 f"AWSSetUpSessionError[{error.__traceback__.tb_lineno}]: {error}"
@@ -698,6 +711,7 @@ class AwsProvider(Provider):
         identity: AWSIdentityInfo,
         assumed_role_configuration: AWSAssumeRoleConfiguration,
         session: AWSSession,
+        session_config: Optional[Config] = None,
     ) -> Session:
         """
         Sets up an assumed session using the provided assumed role credentials.
@@ -742,6 +756,13 @@ class AwsProvider(Provider):
         assumed_session = BotocoreSession()
         assumed_session._credentials = assumed_refreshable_credentials
         assumed_session.set_config_variable("region", identity.profile_region)
+        if session_config is None:
+            session_config = (
+                session.session_config
+                if session is not None
+                else AwsProvider.set_session_config(None)
+            )
+        assumed_session.set_default_client_config(session_config)
         return Session(
             profile_name=identity.profile,
             botocore_session=assumed_session,
@@ -870,7 +891,7 @@ class AwsProvider(Provider):

         for region in enabled_regions:
             regional_client = self._session.current_session.client(
-                service, region_name=region, config=self._session.session_config
+                service, region_name=region
             )
             regional_client.region = region
             regional_clients[region] = regional_client
@@ -1140,21 +1161,16 @@ class AwsProvider(Provider):
         Returns:
         - Config: The botocore Config object
         """
-        # Set the maximum retries for the standard retrier config
-        default_session_config = Config(
-            retries={"max_attempts": 3, "mode": "standard"},
-            user_agent_extra=BOTO3_USER_AGENT_EXTRA,
-        )
+        default_session_config = get_default_session_config()
         if retries_max_attempts:
-            # Create the new config
-            config = Config(
-                retries={
-                    "max_attempts": retries_max_attempts,
-                    "mode": "standard",
-                },
+            default_session_config = default_session_config.merge(
+                Config(
+                    retries={
+                        "max_attempts": retries_max_attempts,
+                        "mode": "standard",
+                    },
+                )
             )
-            # Merge the new configuration
-            default_session_config = default_session_config.merge(config)

         return default_session_config
@@ -1425,6 +1441,9 @@ class AwsProvider(Provider):
             region_name=aws_region,
             profile_name=profile,
         )
+        session._session.set_default_client_config(
+            AwsProvider.set_session_config(None)
+        )

         caller_identity = AwsProvider.validate_credentials(session, aws_region)
         # Do an extra validation if the AWS account ID is provided
@@ -1,6 +1,15 @@
 import os

+from botocore.config import Config
+
 AWS_STS_GLOBAL_ENDPOINT_REGION = "us-east-1"
 AWS_REGION_US_EAST_1 = "us-east-1"
 BOTO3_USER_AGENT_EXTRA = os.getenv("PROWLER_AWS_BOTO3_USER_AGENT_EXTRA", "APN_1826889")
 ROLE_SESSION_NAME = "ProwlerAssessmentSession"
+
+
+def get_default_session_config() -> Config:
+    return Config(
+        user_agent_extra=BOTO3_USER_AGENT_EXTRA,
+        retries={"max_attempts": 3, "mode": "standard"},
+    )
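`set_session_config` builds on the shared default via botocore's `Config.merge`, where the config passed to `merge` takes precedence for any option it sets. A stdlib-only sketch of that precedence rule, with plain dicts standing in for `Config` objects (this is an illustration of the merge semantics, not botocore's implementation):

```python
# Dicts standing in for botocore Config objects: values from the second
# config win on conflict, untouched keys survive from the first.
default_session_config = {
    "user_agent_extra": "APN_1826889",
    "retries": {"max_attempts": 3, "mode": "standard"},
}
override = {"retries": {"max_attempts": 5, "mode": "standard"}}

merged = {**default_session_config, **override}  # later values win, like Config.merge
print(merged["retries"]["max_attempts"])  # → 5
print(merged["user_agent_extra"])         # → APN_1826889
```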
@@ -56,9 +56,7 @@ def quick_inventory(provider: AwsProvider, args):
         try:
             # Scan IAM only once
             if not iam_was_scanned:
-                global_resources.extend(
-                    get_iam_resources(provider.session.current_session)
-                )
+                global_resources.extend(get_iam_resources(provider))
                 iam_was_scanned = True

             # Get regional S3 buckets since none-tagged buckets are not supported by the resourcegroupstaggingapi
@@ -312,8 +310,8 @@ def create_output(resources: list, provider: AwsProvider, args):
         if args.output_bucket:
             output_bucket = args.output_bucket
             bucket_session = provider.session.current_session
-        # Check if -D was input
-        elif args.output_bucket_no_assume:
+        # The outer condition guarantees -D was input when -B was not
+        else:
             output_bucket = args.output_bucket_no_assume
             bucket_session = provider.session.original_session
@@ -375,9 +373,9 @@ def get_regional_buckets(provider: AwsProvider, region: str) -> list:
     return regional_buckets


-def get_iam_resources(session) -> list:
+def get_iam_resources(provider: AwsProvider) -> list:
     iam_resources = []
-    iam_client = session.client("iam")
+    iam_client = provider.session.current_session.client("iam")
     try:
         get_roles_paginator = iam_client.get_paginator("list_roles")
         for page in get_roles_paginator.paginate():
@@ -111,6 +111,13 @@ class S3:
            - None
        """
        if session:
            # Preserve the caller's existing default config (and the
            # retries_max_attempts already baked into it) instead of clobbering
            # it with a freshly built one.
            if session._session.get_default_client_config() is None:
                session._session.set_default_client_config(
                    AwsProvider.set_session_config(retries_max_attempts)
                )
            self._session = session.client(__class__.__name__.lower())
        else:
            aws_setup_session = AwsSetUpSession(
@@ -127,8 +134,7 @@ class S3:
                regions=regions,
            )
            self._session = aws_setup_session._session.current_session.client(
                __class__.__name__.lower(),
                config=aws_setup_session._session.session_config,
                __class__.__name__.lower()
            )

        self._bucket_name = bucket_name
@@ -313,6 +319,9 @@ class S3:
            region_name=aws_region,
            profile_name=profile,
        )
        session._session.set_default_client_config(
            AwsProvider.set_session_config(None)
        )
        s3_client = session.client(__class__.__name__.lower())
        if "s3://" in bucket_name:
            bucket_name = bucket_name.removeprefix("s3://")

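The guard added to the S3 writer (and mirrored in the SecurityHub diff below) can be sketched in isolation: only install a freshly built config when the caller's session does not already carry one. `StubBotocoreSession` and `ensure_default_config` are hypothetical stand-ins for boto3's session internals, not Prowler or boto3 API:

```python
from typing import Optional


class StubBotocoreSession:
    """Hypothetical stand-in for the botocore session inside a boto3 Session."""

    def __init__(self, config: Optional[dict] = None) -> None:
        self._config = config

    def get_default_client_config(self) -> Optional[dict]:
        return self._config

    def set_default_client_config(self, config: dict) -> None:
        self._config = config


def ensure_default_config(botocore_session: StubBotocoreSession, fresh_config: dict) -> dict:
    # Only install the fresh config when the session carries none, so a
    # caller-provided config (e.g. with retries_max_attempts baked in) survives.
    if botocore_session.get_default_client_config() is None:
        botocore_session.set_default_client_config(fresh_config)
    return botocore_session.get_default_client_config()
```

A pre-existing config is preserved; an empty session receives the fresh one.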
@@ -148,6 +148,13 @@ class SecurityHub:
                regions=regions,
            )
            self._session = aws_setup_session._session.current_session
        # Only install the Prowler default config when the caller-supplied
        # session does not already carry one — overwriting would drop the
        # provider's retries_max_attempts value.
        if aws_session and self._session._session.get_default_client_config() is None:
            self._session._session.set_default_client_config(
                AwsProvider.set_session_config(retries_max_attempts)
            )
        self._aws_account_id = aws_account_id
        if not aws_partition:
            aws_partition = AwsProvider.validate_credentials(
@@ -235,7 +242,7 @@ class SecurityHub:

        Args:
            region (str): AWS region to check.
            session (Session): AWS session object.
            session (Session): AWS session object. Expected to carry the Prowler default client config.
            aws_account_id (str): AWS account ID.
            aws_partition (str): AWS partition.

@@ -540,6 +547,9 @@ class SecurityHub:
            region_name=aws_region,
            profile_name=profile,
        )
        session._session.set_default_client_config(
            AwsProvider.set_session_config(None)
        )

        all_regions = AwsProvider.get_available_aws_service_regions(
            service="securityhub", partition=aws_partition

@@ -32,7 +32,13 @@ class AWSService:
    def is_failed_check(cls, check_id, arn):
        return (check_id.split(".")[-1], arn) in cls.failed_checks

    def __init__(self, service: str, provider: AwsProvider, global_service=False):
    def __init__(
        self,
        service: str,
        provider: AwsProvider,
        global_service=False,
        region: str = None,
    ):
        # Audit Information
        # Do we need to store the whole provider?
        self.provider = provider
@@ -61,7 +67,7 @@ class AWSService:
        # Get a single region and client if the service needs it (e.g. AWS Global Service)
        # We cannot include this within an else because some services needs both the regional_clients
        # and a single client like S3
        self.region = provider.get_default_region(
        self.region = region or provider.get_default_region(
            self.service, global_service=global_service
        )
        self.client = self.session.client(self.service, self.region)

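The new `region` parameter combines with the partition checks used by the global-service diffs further down (Global Accelerator, Route53Domains, WAF, WAFv2): a partition-conditional pinned region is computed first, then `region or default` resolves the client region. A rough sketch of both patterns with hypothetical helper names:

```python
from typing import Optional


def global_service_region(partition: str, pinned_region: str) -> Optional[str]:
    # Mirrors: region = "us-west-2" if provider.identity.partition == "aws" else None
    return pinned_region if partition == "aws" else None


def resolve_region(override: Optional[str], default_region: str) -> str:
    # Mirrors: self.region = region or provider.get_default_region(...)
    return override or default_region
```

In the commercial partition the pinned region wins; elsewhere the provider's default region is used unchanged.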
@@ -73,15 +73,15 @@ class AwsSetUpSession:
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
            )
        # Setup the AWS session
        session_config = AwsProvider.set_session_config(retries_max_attempts)
        aws_session = AwsProvider.setup_session(
            mfa=mfa,
            profile=profile,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            session_config=session_config,
        )
        session_config = AwsProvider.set_session_config(retries_max_attempts)
        self._session = AWSSession(
            current_session=aws_session,
            session_config=session_config,

@@ -1,4 +1,5 @@
import datetime
from concurrent.futures import as_completed
from typing import List, Optional

from pydantic.v1 import BaseModel
@@ -14,9 +15,9 @@ class Codebuild(AWSService):
        super().__init__(__class__.__name__, provider)
        self.projects = {}
        self.__threading_call__(self._list_projects)
        self.__threading_call__(self._list_builds_for_project, self.projects.values())
        self.__threading_call__(self._batch_get_builds, self.projects.values())
        self.__threading_call__(self._batch_get_projects, self.projects.values())
        self.__threading_call__(self._list_builds_for_project)
        self.__threading_call__(self._batch_get_builds)
        self.__threading_call__(self._batch_get_projects)
        self.report_groups = {}
        self.__threading_call__(self._list_report_groups)
        self.__threading_call__(
@@ -44,10 +45,8 @@ class Codebuild(AWSService):
                f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    def _list_builds_for_project(self, project):
        logger.info("Codebuild - Listing builds...")
    def _fetch_project_last_build(self, regional_client, project):
        try:
            regional_client = self.regional_clients[project.region]
            build_ids = regional_client.list_builds_for_project(
                projectName=project.name
            ).get("ids", [])
@@ -58,28 +57,99 @@ class Codebuild(AWSService):
                f"{project.region}: {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    def _batch_get_builds(self, project):
        logger.info("Codebuild - Getting builds...")
    def _list_builds_for_project(self, regional_client):
        logger.info("Codebuild - Listing builds...")
        try:
            if project.last_build and project.last_build.id:
                regional_client = self.regional_clients[project.region]
                builds_by_id = regional_client.batch_get_builds(
                    ids=[project.last_build.id]
                ).get("builds", [])
                if len(builds_by_id) > 0:
                    project.last_invoked_time = builds_by_id[0].get("endTime")
            regional_projects = [
                project
                for project in self.projects.values()
                if project.region == regional_client.region
            ]

            # list_builds_for_project has no batch API equivalent, so reuse the
            # shared thread pool to issue per-project calls in parallel within
            # this region — preserving the wall-clock performance of the
            # previous implementation.
            futures = [
                self.thread_pool.submit(
                    self._fetch_project_last_build, regional_client, project
                )
                for project in regional_projects
            ]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception:
                    pass
        except Exception as error:
            logger.error(
                f"{regional_client.region}: {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

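The fan-out in the new `_list_builds_for_project` can be sketched standalone: submit one task per item to a shared pool, drain with `as_completed`, and swallow per-task errors so one failure does not abort the batch. `fetch_all` is a hypothetical reduction of the pattern, not Prowler code:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def fetch_all(items, fetch_one):
    results = []
    with ThreadPoolExecutor(max_workers=4) as pool:
        # One task per item, all running concurrently within the pool
        futures = [pool.submit(fetch_one, item) for item in items]
        for future in as_completed(futures):
            try:
                results.append(future.result())
            except Exception:
                pass  # logged in the real code; skipped in this sketch
    return results
```

Results arrive in completion order, so callers must not rely on input ordering.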
    def _batch_get_projects(self, project):
    def _batch_get_builds(self, regional_client):
        logger.info("Codebuild - Getting builds...")
        try:
            # Collect all build IDs for this region
            build_id_to_project = {}
            for project in self.projects.values():
                if (
                    project.region == regional_client.region
                    and project.last_build
                    and project.last_build.id
                ):
                    build_id_to_project[project.last_build.id] = project

            if not build_id_to_project:
                return

            build_ids = list(build_id_to_project.keys())

            # batch_get_builds supports up to 100 IDs per call
            for i in range(0, len(build_ids), 100):
                batch = build_ids[i : i + 100]
                response = regional_client.batch_get_builds(ids=batch)
                for build_info in response.get("builds", []):
                    build_id = build_info.get("id")
                    if build_id in build_id_to_project:
                        end_time = build_info.get("endTime")
                        if end_time:
                            build_id_to_project[build_id].last_invoked_time = end_time
        except Exception as error:
            logger.error(
                f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    def _batch_get_projects(self, regional_client):
        logger.info("Codebuild - Getting projects...")
        try:
            regional_client = self.regional_clients[project.region]
            project_info = regional_client.batch_get_projects(names=[project.name])[
                "projects"
            ][0]
            # Collect all project names for this region
            regional_projects = {
                arn: project
                for arn, project in self.projects.items()
                if project.region == regional_client.region
            }
            if not regional_projects:
                return

            project_names = [project.name for project in regional_projects.values()]

            # batch_get_projects supports up to 100 names per call
            for i in range(0, len(project_names), 100):
                batch = project_names[i : i + 100]
                response = regional_client.batch_get_projects(names=batch)
                for project_info in response.get("projects", []):
                    project_arn = project_info.get("arn")
                    if project_arn in regional_projects:
                        self._parse_project_info(
                            regional_projects[project_arn], project_info
                        )
        except Exception as error:
            logger.error(
                f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    def _parse_project_info(self, project, project_info):
        try:
            project.buildspec = project_info["source"].get("buildspec")
            if project_info["source"]["type"] != "NO_SOURCE":
                project.source = Source(

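Both `_batch_get_builds` and `_batch_get_projects` rely on the same slicing idiom to stay within the 100-item-per-call limit noted in their comments. A minimal sketch of the chunking:

```python
def chunks(ids, size=100):
    # Yield consecutive slices of at most `size` items; the last slice
    # may be shorter. Mirrors: for i in range(0, len(build_ids), 100)
    for i in range(0, len(ids), size):
        yield ids[i : i + size]
```

For 250 IDs this yields batches of 100, 100, and 50, so each batch API call stays under the limit.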
@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService

class GlobalAccelerator(AWSService):
    def __init__(self, provider):
        # Call AWSService's __init__
        super().__init__(__class__.__name__, provider)
        # Global Accelerator is a global service that supports endpoints in multiple AWS Regions
        # but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
        # That is, for example, specify --region us-west-2 on AWS CLI commands.
        region = "us-west-2" if provider.identity.partition == "aws" else None
        super().__init__(__class__.__name__, provider, region=region)
        self.accelerators = {}
        if self.audited_partition == "aws":
            # Global Accelerator is a global service that supports endpoints in multiple AWS Regions
            # but you must specify the US West (Oregon) Region to create, update, or otherwise work with accelerators.
            # That is, for example, specify --region us-west-2 on AWS CLI commands.
            self.region = "us-west-2"
            self.client = self.session.client(self.service, self.region)
            self._list_accelerators()
            self.__threading_call__(self._list_tags, self.accelerators.values())


@@ -176,14 +176,12 @@ class RecordSet(BaseModel):

class Route53Domains(AWSService):
    def __init__(self, provider):
        # Call AWSService's __init__
        super().__init__(__class__.__name__, provider)
        # Route53Domains is a global service that supports endpoints in multiple AWS Regions
        # but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
        region = "us-east-1" if provider.identity.partition == "aws" else None
        super().__init__(__class__.__name__, provider, region=region)
        self.domains = {}
        if self.audited_partition == "aws":
            # Route53Domains is a global service that supports endpoints in multiple AWS Regions
            # but you must specify the US East (N. Virginia) Region to create, update, or otherwise work with domains.
            self.region = "us-east-1"
            self.client = self.session.client(self.service, self.region)
            self._list_domains()
            self._get_domain_detail()
            self._list_tags_for_domain()

@@ -9,20 +9,20 @@ from prowler.providers.aws.lib.service.service import AWSService

class TrustedAdvisor(AWSService):
    def __init__(self, provider):
        # Call AWSService's __init__
        super().__init__("support", provider)
        # Support API is not available in China Partition
        # But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
        partition = provider.identity.partition
        if partition == "aws":
            support_region = "us-east-1"
        elif partition == "aws-cn":
            support_region = None
        else:
            support_region = "us-gov-west-1"
        super().__init__("support", provider, region=support_region)
        self.account_arn_template = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:account"
        self.checks = []
        self.premium_support = PremiumSupport(enabled=False)
        # Support API is not available in China Partition
        # But only in us-east-1 or us-gov-west-1 https://docs.aws.amazon.com/general/latest/gr/awssupport.html
        if self.audited_partition != "aws-cn":
            if self.audited_partition == "aws":
                support_region = "us-east-1"
            else:
                support_region = "us-gov-west-1"
            self.client = self.session.client(self.service, region_name=support_region)
            self.client.region = support_region
            self._describe_services()
            if getattr(self.premium_support, "enabled", False):
                self._describe_trusted_advisor_checks()
@@ -34,13 +34,13 @@ class TrustedAdvisor(AWSService):
            for check in self.client.describe_trusted_advisor_checks(language="en").get(
                "checks", []
            ):
                check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.client.region}:{self.audited_account}:check/{check['id']}"
                check_arn = f"arn:{self.audited_partition}:trusted-advisor:{self.region}:{self.audited_account}:check/{check['id']}"
                self.checks.append(
                    Check(
                        id=check["id"],
                        name=check["name"],
                        arn=check_arn,
                        region=self.client.region,
                        region=self.region,
                    )
                )
        except ClientError as error:
@@ -50,22 +50,22 @@ class TrustedAdvisor(AWSService):
                == "Amazon Web Services Premium Support Subscription is required to use this service."
            ):
                logger.warning(
                    f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                    f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                )
            else:
                logger.error(
                    f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                    f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                )
        except Exception as error:
            logger.error(
                f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
            )

    def _describe_trusted_advisor_check_result(self):
        logger.info("TrustedAdvisor - Describing Check Result...")
        try:
            for check in self.checks:
                if check.region == self.client.region:
                if check.region == self.region:
                    try:
                        response = self.client.describe_trusted_advisor_check_result(
                            checkId=check.id
@@ -78,11 +78,11 @@ class TrustedAdvisor(AWSService):
                            == "InvalidParameterValueException"
                        ):
                            logger.warning(
                                f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                                f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                            )
                    except Exception as error:
                        logger.error(
                            f"{self.client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                            f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
                        )

    def _describe_services(self):

@@ -9,15 +9,13 @@ from prowler.providers.aws.lib.service.service import AWSService

class WAF(AWSService):
    def __init__(self, provider):
        # Call AWSService's __init__
        super().__init__("waf", provider)
        # AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
        region = "us-east-1" if provider.identity.partition == "aws" else None
        super().__init__("waf", provider, region=region)
        self.rules = {}
        self.rule_groups = {}
        self.web_acls = {}
        if self.audited_partition == "aws":
            # AWS WAF is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL and any resources used in the web ACL, such as rule groups, IP sets, and regex pattern sets.
            self.region = "us-east-1"
            self.client = self.session.client(self.service, self.region)
            self._list_rules()
            self.__threading_call__(self._get_rule, self.rules.values())
            self._list_rule_groups()

@@ -11,13 +11,11 @@ from prowler.providers.aws.lib.service.service import AWSService

class WAFv2(AWSService):
    def __init__(self, provider):
        # Call AWSService's __init__
        super().__init__(__class__.__name__, provider)
        # AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
        region = "us-east-1" if provider.identity.partition == "aws" else None
        super().__init__(__class__.__name__, provider, region=region)
        self.web_acls = {}
        if self.audited_partition == "aws":
            # AWS WAFv2 is available globally for CloudFront distributions, but you must use the Region US East (N. Virginia) to create your web ACL.
            self.region = "us-east-1"
            self.client = self.session.client(self.service, self.region)
            self._list_web_acls_global()
            self.__threading_call__(self._list_web_acls_regional)
            self.__threading_call__(self._get_web_acl, self.web_acls.values())

@@ -436,6 +436,33 @@ class Test_Config:
        assert "csa_ccm_4.0" in aws_frameworks
        assert "csa_ccm_4.0" not in kubernetes_frameworks

    def test_get_available_compliance_frameworks_no_provider_includes_universals(self):
        """Regression test for the variable shadowing bug.

        Previously, the inner ``for provider in providers`` loop shadowed
        the outer ``provider`` parameter. When called without a provider,
        the post-loop ``if provider:`` branch wrongly applied
        ``framework.supports_provider(<last provider iterated>)`` and
        excluded universal frameworks from the result.

        Result: the parser-level ``available_compliance_frameworks``
        constant was missing universal frameworks like ``csa_ccm_4.0``,
        which made ``--compliance csa_ccm_4.0`` reject the choice.
        """
        all_frameworks = get_available_compliance_frameworks()
        assert "csa_ccm_4.0" in all_frameworks

    def test_get_available_compliance_frameworks_does_not_mutate_provider_param(self):
        """Calling with a specific provider must not affect a subsequent
        call without provider. Validates that the loop variable rename
        prevents leaking state between calls."""
        # Force an iteration over multiple providers first
        get_available_compliance_frameworks("kubernetes")
        # Then a no-provider call must still include universals supported
        # by ANY provider (not filtered by some leaked value)
        all_frameworks = get_available_compliance_frameworks()
        assert "csa_ccm_4.0" in all_frameworks

    def test_load_and_validate_config_file_aws(self):
        path = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
        config_test_file = f"{path}/fixtures/config.yaml"

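The shadowing bug these regression tests document can be reproduced in a few lines: a loop variable reuses the function parameter's name, so the post-loop check sees the last iterated value instead of the caller's argument. `buggy_filter` and `fixed_filter` are hypothetical reductions, not Prowler code:

```python
def buggy_filter(items, provider=None):
    for provider in ["aws", "azure"]:  # shadows the parameter!
        pass
    if provider:  # always truthy after the loop, even when called with None
        return [i for i in items if i == provider]
    return items


def fixed_filter(items, provider=None):
    for _provider in ["aws", "azure"]:  # renamed loop variable, no shadowing
        pass
    if provider:
        return [i for i in items if i == provider]
    return items
```

Called without a provider, the buggy version filters by the leftover loop value ("azure") and drops everything, while the fixed version returns the input unchanged.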
@@ -675,3 +675,177 @@ class TestCheckLoader:
        )
        assert CLOUDTRAIL_THREAT_DETECTION_ENUMERATION_NAME not in result
        assert S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME in result

    def test_load_checks_to_execute_universal_framework_takes_precedence(self):
        """When ``--compliance <fw>`` matches a universal framework, the
        loader must source checks from ``universal_frameworks[fw].requirements[*]
        .checks[provider]`` and NOT fall through to ``bulk_compliance_frameworks``.

        This is the path added by PR #10301 in checks_loader.py.
        """
        from prowler.lib.check.compliance_models import (
            ComplianceFramework,
            UniversalComplianceRequirement,
        )

        bulk_checks_metadata = {
            S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
        }

        universal_framework = ComplianceFramework(
            framework="csa_ccm",
            name="CSA CCM 4.0",
            version="4.0",
            description="Cloud Controls Matrix",
            requirements=[
                UniversalComplianceRequirement(
                    id="A&A-01",
                    description="Audit & Assurance",
                    attributes={},
                    checks={"aws": [S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME]},
                ),
            ],
        )

        with patch(
            "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
            return_value=bulk_checks_metadata,
        ):
            result = load_checks_to_execute(
                bulk_checks_metadata=bulk_checks_metadata,
                bulk_compliance_frameworks={},  # legacy empty
                compliance_frameworks=["csa_ccm_4.0"],
                provider=self.provider,
                universal_frameworks={"csa_ccm_4.0": universal_framework},
            )

        assert result == {S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME}

    def test_load_checks_to_execute_universal_filters_by_provider(self):
        """A universal requirement may declare checks for several
        providers; the loader must only return those for the active
        provider key (lowercased)."""
        from prowler.lib.check.compliance_models import (
            ComplianceFramework,
            UniversalComplianceRequirement,
        )

        bulk_checks_metadata = {
            S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
        }

        # The same requirement maps a different check per provider.
        # Only the AWS one must be returned for provider="aws".
        universal_framework = ComplianceFramework(
            framework="csa_ccm",
            name="CSA CCM 4.0",
            version="4.0",
            description="Cloud Controls Matrix",
            requirements=[
                UniversalComplianceRequirement(
                    id="A&A-02",
                    description="Multi-provider req",
                    attributes={},
                    checks={
                        "aws": [S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME],
                        "azure": ["azure_only_check"],
                        "gcp": ["gcp_only_check"],
                    },
                ),
            ],
        )

        with patch(
            "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
            return_value=bulk_checks_metadata,
        ):
            result = load_checks_to_execute(
                bulk_checks_metadata=bulk_checks_metadata,
                bulk_compliance_frameworks={},
                compliance_frameworks=["csa_ccm_4.0"],
                provider=self.provider,  # "aws"
                universal_frameworks={"csa_ccm_4.0": universal_framework},
            )

        assert S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME in result
        assert "azure_only_check" not in result
        assert "gcp_only_check" not in result

    def test_load_checks_to_execute_universal_no_match_falls_back_to_legacy(self):
        """If the requested compliance framework is not present in
        ``universal_frameworks``, the loader must fall back to the
        legacy ``bulk_compliance_frameworks`` lookup."""
        bulk_checks_metadata = {
            S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
        }
        bulk_compliance_frameworks = {
            "soc2_aws": Compliance(
                Framework="SOC2",
                Name="SOC2",
                Provider="aws",
                Version="2.0",
                Description="x",
                Requirements=[
                    Compliance_Requirement(
                        Checks=[S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME],
                        Id="",
                        Description="",
                        Attributes=[],
                    )
                ],
            ),
        }

        with patch(
            "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
            return_value=bulk_checks_metadata,
        ):
            result = load_checks_to_execute(
                bulk_checks_metadata=bulk_checks_metadata,
                bulk_compliance_frameworks=bulk_compliance_frameworks,
                compliance_frameworks=["soc2_aws"],
                provider=self.provider,
                universal_frameworks={"some_other_universal_fw": object()},
            )

        assert result == {S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME}

    def test_load_checks_to_execute_universal_unknown_provider_returns_empty(self):
        """If the universal requirement has no checks for the active
        provider, no checks are picked up for that requirement."""
        from prowler.lib.check.compliance_models import (
            ComplianceFramework,
            UniversalComplianceRequirement,
        )

        bulk_checks_metadata = {
            S3_BUCKET_LEVEL_PUBLIC_ACCESS_BLOCK_NAME: self.get_custom_check_s3_metadata()
        }
        universal_framework = ComplianceFramework(
            framework="csa_ccm",
            name="CSA CCM 4.0",
            version="4.0",
            description="Cloud Controls Matrix",
            requirements=[
                UniversalComplianceRequirement(
                    id="A&A-03",
                    description="Only Azure",
                    attributes={},
                    checks={"azure": ["azure_only_check"]},
                ),
            ],
        )

        with patch(
            "prowler.lib.check.checks_loader.CheckMetadata.get_bulk",
            return_value=bulk_checks_metadata,
        ):
            result = load_checks_to_execute(
                bulk_checks_metadata=bulk_checks_metadata,
                bulk_compliance_frameworks={},
                compliance_frameworks=["csa_ccm_4.0"],
                provider=self.provider,  # "aws" — no checks declared
                universal_frameworks={"csa_ccm_4.0": universal_framework},
            )

        assert result == set()

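The provider-keyed lookup these tests exercise reduces to collecting each requirement's checks for the active provider, lowercased. The data shapes below are illustrative dicts, not the real `ComplianceFramework` model:

```python
def checks_for_provider(requirements, provider):
    # Each requirement maps provider name -> list of check IDs; only the
    # active provider's entries are collected into the result set.
    selected = set()
    for requirement in requirements:
        selected.update(requirement.get("checks", {}).get(provider.lower(), []))
    return selected
```

A provider with no declared checks simply contributes nothing, matching the "unknown provider returns empty" test above.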
@@ -442,3 +442,123 @@ class TestComplianceOutput:
|
||||
)
|
||||
|
||||
assert compliance_output.file_extension == ".csv"
|
||||
|
||||
|
||||
class TestComplianceCheckHelperModule:
|
||||
"""Tests for the new ``compliance_check`` leaf module that hosts
|
||||
``get_check_compliance``.
|
||||
|
||||
This module exists to break the cyclic import chain
|
||||
``finding -> compliance.compliance -> universal.* -> finding`` that
|
||||
CodeQL flagged. It must be:
|
||||
- importable directly without pulling in the universal pipeline
|
||||
- re-exported by ``compliance.compliance`` for backward compatibility
|
||||
- the SAME function object, regardless of import path
|
||||
"""
|
||||
|
||||
def test_module_is_importable_directly(self):
|
||||
"""The helper module must be importable on its own — it is the
|
||||
leaf used by ``finding.py`` to break the cyclic import chain."""
|
||||
from prowler.lib.outputs.compliance import compliance_check
|
||||
|
||||
assert hasattr(compliance_check, "get_check_compliance")
|
||||
assert callable(compliance_check.get_check_compliance)
|
||||
|
||||
def test_helper_module_only_depends_on_check_models_and_logger(self):
|
||||
"""The helper must not pull in universal pipeline modules; that
|
||||
was the whole point of extracting it. Inspecting the module's
|
||||
own imports keeps it honest without polluting ``sys.modules``."""
|
||||
import inspect
|
||||
|
||||
from prowler.lib.outputs.compliance import compliance_check
|
||||
|
||||
source = inspect.getsource(compliance_check)
|
||||
# Only these two prowler imports are allowed in the leaf module
|
||||
assert "from prowler.lib.check.models import Check_Report" in source
|
||||
assert "from prowler.lib.logger import logger" in source
|
||||
# And NOT these (would re-introduce the cycle):
|
||||
assert "from prowler.lib.outputs.compliance.universal" not in source
|
||||
assert "from prowler.lib.outputs.finding" not in source
|
||||
assert "from prowler.lib.outputs.ocsf" not in source
|
||||
|
||||
def test_re_export_from_compliance_compliance(self):
|
||||
"""``compliance.compliance.get_check_compliance`` must point to
|
||||
the same function as ``compliance.compliance_check.get_check_compliance``."""
|
||||
from prowler.lib.outputs.compliance.compliance import (
|
||||
get_check_compliance as via_compliance,
        )

        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance as via_helper,
        )

        assert via_compliance is via_helper

    def test_re_export_from_finding_module(self):
        """``finding.get_check_compliance`` must point to the same
        function. Test mocks rely on this attribute existing on the
        ``prowler.lib.outputs.finding`` module."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance as via_helper,
        )
        from prowler.lib.outputs.finding import get_check_compliance as via_finding

        assert via_finding is via_helper

    def test_returns_empty_dict_on_unknown_check(self):
        """Sanity test of the function logic via the helper module."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance,
        )

        finding = mock.MagicMock()
        finding.check_metadata.CheckID = "unknown_check_id"
        result = get_check_compliance(finding, "aws", {})
        assert result == {}

    def test_filters_by_provider(self):
        """The function returns frameworks only for the matching provider."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance,
        )

        compliance_aws = mock.MagicMock(
            Framework="CIS",
            Version="1.4",
            Provider="AWS",
            Requirements=[mock.MagicMock(Id="2.1.3")],
        )
        compliance_azure = mock.MagicMock(
            Framework="CIS",
            Version="2.0",
            Provider="Azure",
            Requirements=[mock.MagicMock(Id="9.1")],
        )
        finding = mock.MagicMock()
        finding.check_metadata.CheckID = "shared_check"
        bulk = {
            "shared_check": mock.MagicMock(
                Compliance=[compliance_aws, compliance_azure]
            )
        }

        # Only AWS frameworks come back
        result = get_check_compliance(finding, "aws", bulk)
        assert "CIS-1.4" in result
        assert "CIS-2.0" not in result

    def test_returns_empty_dict_on_exception(self):
        """If iteration raises, the function logs the error and returns
        an empty dict (defensive behaviour)."""
        from prowler.lib.outputs.compliance.compliance_check import (
            get_check_compliance,
        )

        # bulk_checks_metadata that raises when accessed → defensive path
        class Boom:
            def __contains__(self, _key):
                raise RuntimeError("boom")

        finding = mock.MagicMock()
        finding.check_metadata.CheckID = "any"
        result = get_check_compliance(finding, "aws", Boom())
        assert result == {}

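The `is` assertions above rely on a basic property of Python imports: a module-level re-export binds a new name to the same function object, so identity holds across both access paths. A minimal sketch of the pattern (the module and function names here are illustrative, not Prowler's):

```python
import types


def get_check_compliance(finding, provider, bulk):
    """Stand-in for the real helper; returns an empty mapping."""
    return {}


# A facade module re-exporting the helper binds the same object...
facade = types.ModuleType("facade")
facade.get_check_compliance = get_check_compliance

# ...so identity comparisons hold across both access paths.
assert facade.get_check_compliance is get_check_compliance
```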
@@ -0,0 +1,244 @@
"""Tests for display_compliance_table dispatch logic.

Validates that each compliance framework name is routed to the correct
table renderer via startswith matching, and that the universal early-return
takes precedence when applicable.
"""

from unittest.mock import patch

import pytest

from prowler.lib.check.compliance_models import (
    ComplianceFramework,
    OutputsConfig,
    TableConfig,
    UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.compliance import display_compliance_table

MODULE = "prowler.lib.outputs.compliance.compliance"

# Common args shared by every call — the actual values don't matter
# because we mock the downstream renderers.
_COMMON = dict(
    findings=[],
    bulk_checks_metadata={},
    output_filename="out",
    output_directory="/tmp",
    compliance_overview=False,
)


# ── Dispatch to legacy table renderers ───────────────────────────────


class TestDispatchStartswith:
    """Each framework prefix must route to exactly one renderer."""

    @pytest.mark.parametrize(
        "framework_name",
        [
            "cis_1.4_aws",
            "cis_2.0_azure",
            "cis_3.0_gcp",
            "cis_6.0_m365",
            "cis_1.10_kubernetes",
        ],
    )
    @patch(f"{MODULE}.get_cis_table")
    def test_cis_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["ens_rd2022_aws", "ens_rd2022_azure", "ens_rd2022_gcp"],
    )
    @patch(f"{MODULE}.get_ens_table")
    def test_ens_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["mitre_attack_aws", "mitre_attack_azure", "mitre_attack_gcp"],
    )
    @patch(f"{MODULE}.get_mitre_attack_table")
    def test_mitre_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["kisa_isms_p_2023_aws", "kisa_isms_p_2023_korean_aws"],
    )
    @patch(f"{MODULE}.get_kisa_ismsp_table")
    def test_kisa_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        [
            "prowler_threatscore_aws",
            "prowler_threatscore_azure",
            "prowler_threatscore_gcp",
            "prowler_threatscore_kubernetes",
            "prowler_threatscore_m365",
            "prowler_threatscore_alibabacloud",
        ],
    )
    @patch(f"{MODULE}.get_prowler_threatscore_table")
    def test_threatscore_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        [
            "csa_ccm_4.0_aws",
            "csa_ccm_4.0_azure",
            "csa_ccm_4.0_gcp",
            "csa_ccm_4.0_oraclecloud",
            "csa_ccm_4.0_alibabacloud",
        ],
    )
    @patch(f"{MODULE}.get_csa_table")
    def test_csa_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        ["c5_aws", "c5_azure", "c5_gcp"],
    )
    @patch(f"{MODULE}.get_c5_table")
    def test_c5_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()

    @pytest.mark.parametrize(
        "framework_name",
        [
            "soc2_aws",
            "hipaa_aws",
            "gdpr_aws",
            "nist_800_53_revision_4_aws",
            "pci_3.2.1_aws",
            "iso27001_2013_aws",
            "aws_well_architected_framework_security_pillar_aws",
            "fedramp_low_revision_4_aws",
            "cisa_aws",
        ],
    )
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_generic_dispatch(self, mock_fn, framework_name):
        display_compliance_table(compliance_framework=framework_name, **_COMMON)
        mock_fn.assert_called_once()


# ── No false matches (the old `in` bug) ─────────────────────────────


class TestNoFalseSubstringMatches:
    """Frameworks that previously could false-match with `in` must NOT
    be routed to the wrong renderer now that we use startswith."""

    # NOTE: patch decorators are applied bottom-up, so the bottom patch
    # binds to the first mock argument.
    @patch(f"{MODULE}.get_cis_table")
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_cisa_does_not_match_cis(self, mock_generic, mock_cis):
        """'cisa_aws' must NOT match startswith('cis_')."""
        display_compliance_table(compliance_framework="cisa_aws", **_COMMON)
        mock_generic.assert_called_once()
        mock_cis.assert_not_called()

    @patch(f"{MODULE}.get_prowler_threatscore_table")
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_threatscore_prefix_not_partial(self, mock_generic, mock_ts):
        """A hypothetical 'threatscore_custom_aws' must NOT match
        startswith('prowler_threatscore_')."""
        display_compliance_table(
            compliance_framework="threatscore_custom_aws", **_COMMON
        )
        mock_generic.assert_called_once()
        mock_ts.assert_not_called()

    @patch(f"{MODULE}.get_ens_table")
    @patch(f"{MODULE}.get_prowler_threatscore_table")
    def test_prowler_threatscore_does_not_match_ens(self, mock_ts, mock_ens):
        """'prowler_threatscore_aws' must hit threatscore, never ens."""
        display_compliance_table(
            compliance_framework="prowler_threatscore_aws", **_COMMON
        )
        mock_ts.assert_called_once()
        mock_ens.assert_not_called()


# ── Universal early-return ───────────────────────────────────────────


class TestUniversalEarlyReturn:
    """The universal path must take precedence over the elif chain."""

    @staticmethod
    def _make_fw():
        return ComplianceFramework(
            framework="CIS",
            name="CIS",
            provider="AWS",
            version="5.0",
            description="d",
            requirements=[
                UniversalComplianceRequirement(
                    id="1.1",
                    description="d",
                    attributes={},
                    checks={"aws": ["check_a"]},
                ),
            ],
            outputs=OutputsConfig(table_config=TableConfig(group_by="_default")),
        )

    @patch(f"{MODULE}.get_universal_table")
    @patch(f"{MODULE}.get_cis_table")
    def test_universal_takes_precedence_over_cis(self, mock_cis, mock_universal):
        """A CIS framework in universal_frameworks with TableConfig must
        use the universal renderer, not get_cis_table."""
        fw = self._make_fw()
        display_compliance_table(
            compliance_framework="cis_5.0_aws",
            universal_frameworks={"cis_5.0_aws": fw},
            **_COMMON,
        )
        mock_universal.assert_called_once()
        mock_cis.assert_not_called()

    @patch(f"{MODULE}.get_universal_table")
    @patch(f"{MODULE}.get_cis_table")
    def test_falls_through_without_table_config(self, mock_cis, mock_universal):
        """If the universal framework has no TableConfig, fall through
        to the legacy elif chain."""
        fw = self._make_fw()
        fw.outputs = None
        display_compliance_table(
            compliance_framework="cis_5.0_aws",
            universal_frameworks={"cis_5.0_aws": fw},
            **_COMMON,
        )
        mock_cis.assert_called_once()
        mock_universal.assert_not_called()

    @patch(f"{MODULE}.get_universal_table")
    @patch(f"{MODULE}.get_generic_compliance_table")
    def test_falls_through_when_not_in_universal_dict(
        self, mock_generic, mock_universal
    ):
        """If universal_frameworks is empty, fall through to legacy."""
        display_compliance_table(
            compliance_framework="soc2_aws",
            universal_frameworks={},
            **_COMMON,
        )
        mock_generic.assert_called_once()
        mock_universal.assert_not_called()
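The false-match guards above protect the core fix: substring matching would route `'cisa_aws'` to the CIS renderer because `'cis'` is contained in the name, while prefix matching with a trailing underscore does not. A minimal sketch of the difference (the renderer labels are placeholders, not Prowler's actual return values):

```python
def route(framework: str) -> str:
    """Prefix-based dispatch, mirroring the startswith logic under test."""
    if framework.startswith("cis_"):
        return "cis"
    if framework.startswith("prowler_threatscore_"):
        return "threatscore"
    return "generic"


# 'cisa_aws' contains 'cis' but does not start with 'cis_',
# so it falls through to the generic renderer.
assert route("cis_1.4_aws") == "cis"
assert route("cisa_aws") == "generic"
assert "cis" in "cisa_aws"  # the old `in` check would have false-matched
```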
@@ -0,0 +1,128 @@
from io import StringIO
from unittest import mock

from freezegun import freeze_time
from mock import patch

from prowler.lib.outputs.compliance.essential_eight.essential_eight_aws import (
    EssentialEightAWS,
)
from prowler.lib.outputs.compliance.essential_eight.models import (
    EssentialEightAWSModel,
)
from tests.lib.outputs.compliance.fixtures import ESSENTIAL_EIGHT_AWS
from tests.lib.outputs.fixtures.fixtures import generate_finding_output
from tests.providers.aws.utils import AWS_ACCOUNT_NUMBER, AWS_REGION_EU_WEST_1

# The fixture's first Requirement maps clause "E8-1.8" (Patch applications,
# clause 8: removal of unsupported online services). The second Requirement is
# E8-6.1 (Restrict Office macros, clause 1), which has no Checks and is therefore
# emitted as a manual row.
COMPLIANCE_NAME = "Essential-Eight-Nov 2023"


class TestEssentialEightAWS:
    def test_output_transform(self):
        findings = [generate_finding_output(compliance={COMPLIANCE_NAME: "E8-1.8"})]

        output = EssentialEightAWS(findings, ESSENTIAL_EIGHT_AWS)
        output_data = output.data[0]
        assert isinstance(output_data, EssentialEightAWSModel)
        assert output_data.Provider == "aws"
        assert output_data.Framework == ESSENTIAL_EIGHT_AWS.Framework
        assert output_data.Name == ESSENTIAL_EIGHT_AWS.Name
        assert output_data.Description == ESSENTIAL_EIGHT_AWS.Description
        assert output_data.AccountId == AWS_ACCOUNT_NUMBER
        assert output_data.Region == AWS_REGION_EU_WEST_1
        assert output_data.Requirements_Id == "E8-1.8"
        assert (
            output_data.Requirements_Description
            == ESSENTIAL_EIGHT_AWS.Requirements[0].Description
        )
        assert output_data.Requirements_Attributes_Section == "1 Patch applications"
        assert output_data.Requirements_Attributes_MaturityLevel == "ML1"
        assert output_data.Requirements_Attributes_AssessmentStatus == "Automated"
        assert output_data.Requirements_Attributes_CloudApplicability == "full"
        assert (
            output_data.Requirements_Attributes_MitigatedThreats
            == "Use of unsupported software, Long-tail vulnerability accumulation"
        )
        assert (
            output_data.Requirements_Attributes_Description
            == ESSENTIAL_EIGHT_AWS.Requirements[0].Attributes[0].Description
        )
        assert output_data.Status == "PASS"
        assert output_data.StatusExtended == ""
        assert output_data.ResourceId == ""
        assert output_data.ResourceName == ""
        assert output_data.CheckId == "service_test_check_id"
        assert not output_data.Muted

    def test_manual_requirement(self):
        findings = [generate_finding_output(compliance={COMPLIANCE_NAME: "E8-1.8"})]
        output = EssentialEightAWS(findings, ESSENTIAL_EIGHT_AWS)

        # E8-6.1 (macros) has no Checks -> emitted as a manual row, non-applicable
        manual_rows = [row for row in output.data if row.Status == "MANUAL"]
        assert len(manual_rows) == 1

        manual = manual_rows[0]
        assert manual.Provider == "aws"
        assert manual.AccountId == ""
        assert manual.Region == ""
        assert manual.Requirements_Id == "E8-6.1"
        assert (
            manual.Requirements_Attributes_Section
            == "6 Restrict Microsoft Office macros"
        )
        assert manual.Requirements_Attributes_MaturityLevel == "ML1"
        assert manual.Requirements_Attributes_AssessmentStatus == "Manual"
        assert manual.Requirements_Attributes_CloudApplicability == "non-applicable"
        assert (
            manual.Requirements_Attributes_MitigatedThreats
            == "Macro-based malware delivery"
        )
        assert manual.StatusExtended == "Manual check"
        assert manual.ResourceId == "manual_check"
        assert manual.ResourceName == "Manual check"
        assert manual.CheckId == "manual"
        assert not manual.Muted

    @freeze_time("2025-01-01 00:00:00")
    @mock.patch(
        "prowler.lib.outputs.compliance.essential_eight.essential_eight_aws.timestamp",
        "2025-01-01 00:00:00",
    )
    def test_batch_write_data_to_file(self):
        mock_file = StringIO()
        findings = [generate_finding_output(compliance={COMPLIANCE_NAME: "E8-1.8"})]
        output = EssentialEightAWS(findings, ESSENTIAL_EIGHT_AWS)
        output._file_descriptor = mock_file

        with patch.object(mock_file, "close", return_value=None):
            output.batch_write_data_to_file()

        mock_file.seek(0)
        content = mock_file.read()

        # Validate that the header carries the E8-specific column names
        first_line = content.split("\r\n", 1)[0]
        for column in (
            "REQUIREMENTS_ATTRIBUTES_MATURITYLEVEL",
            "REQUIREMENTS_ATTRIBUTES_ASSESSMENTSTATUS",
            "REQUIREMENTS_ATTRIBUTES_CLOUDAPPLICABILITY",
            "REQUIREMENTS_ATTRIBUTES_MITIGATEDTHREATS",
            "REQUIREMENTS_ATTRIBUTES_RATIONALESTATEMENT",
            "REQUIREMENTS_ATTRIBUTES_REMEDIATIONPROCEDURE",
            "REQUIREMENTS_ATTRIBUTES_AUDITPROCEDURE",
        ):
            assert column in first_line, f"missing column {column} in CSV header"

        # rows: header + matched + manual
        rows = [r for r in content.split("\r\n") if r]
        assert len(rows) == 3
        assert rows[1].split(";")[0] == "aws"
        assert "ML1" in rows[1]
        assert ";PASS;" in rows[1]
        assert ";MANUAL;" in rows[2]
        assert ";manual_check;" in rows[2]
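The `batch_write_data_to_file` test pairs `freeze_time` with a direct `mock.patch` of the module-level `timestamp` attribute. The second patch is needed because a value bound at import time is not affected by freezing the clock afterwards; it must be patched on the module where it lives. A small illustration of that general rule (the `report` module here is hypothetical, built inline for the example):

```python
import types
from datetime import datetime
from unittest import mock

# A module that binds a timestamp once, at "import" time.
report = types.ModuleType("report")
report.timestamp = datetime(2020, 1, 1)
report.render = lambda: f"generated at {report.timestamp.year}"

# Freezing the clock later would not change the already-bound value;
# patching the attribute where it was bound does.
with mock.patch.object(report, "timestamp", datetime(2025, 1, 1)):
    assert report.render() == "generated at 2025"

# The patch is reverted on exiting the context manager.
assert report.render() == "generated at 2020"
```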
@@ -7,6 +7,7 @@ from prowler.lib.check.compliance_models import (
    ENS_Requirement_Attribute,
    ENS_Requirement_Attribute_Nivel,
    ENS_Requirement_Attribute_Tipos,
    EssentialEight_Requirement_Attribute,
    Generic_Compliance_Requirement_Attribute,
    ISO27001_2013_Requirement_Attribute,
    KISA_ISMSP_Requirement_Attribute,
@@ -1189,3 +1190,58 @@ CCC_GCP_FIXTURE = Compliance(
    ),
    ],
)

ESSENTIAL_EIGHT_AWS = Compliance(
    Framework="Essential-Eight",
    Name="ASD Essential Eight Maturity Model - Maturity Level One (AWS)",
    Version="Nov 2023",
    Provider="AWS",
    Description="Literal mapping of the Australian Signals Directorate (ASD) Essential Eight Maturity Model ML1 to AWS infrastructure checks.",
    Requirements=[
        Compliance_Requirement(
            Id="E8-1.8",
            Description="Online services that are no longer supported by vendors are removed.",
            Attributes=[
                EssentialEight_Requirement_Attribute(
                    Section="1 Patch applications",
                    MaturityLevel="ML1",
                    AssessmentStatus="Automated",
                    CloudApplicability="full",
                    MitigatedThreats=[
                        "Use of unsupported software",
                        "Long-tail vulnerability accumulation",
                    ],
                    Description="Detect and remove unsupported AWS-hosted online services (Lambda runtimes, RDS engines, EKS, Fargate, Kafka, OpenSearch).",
                    RationaleStatement="Unsupported services no longer receive security patches.",
                    ImpactStatement="",
                    RemediationProcedure="Migrate Lambda off deprecated runtimes; remove RDS Extended Support; upgrade EKS.",
                    AuditProcedure="Run all listed checks.",
                    AdditionalInformation="ASD Essential Eight ML1 - Patch applications - clause 8.",
                    References="https://www.cyber.gov.au/resources-business-and-government/essential-cyber-security/essential-eight/essential-eight-maturity-model",
                )
            ],
            Checks=["service_test_check_id"],
        ),
        Compliance_Requirement(
            Id="E8-6.1",
            Description="Microsoft Office macros are disabled for users that do not have a demonstrated business requirement.",
            Attributes=[
                EssentialEight_Requirement_Attribute(
                    Section="6 Restrict Microsoft Office macros",
                    MaturityLevel="ML1",
                    AssessmentStatus="Manual",
                    CloudApplicability="non-applicable",
                    MitigatedThreats=["Macro-based malware delivery"],
                    Description="Endpoint / Microsoft 365 control. Out of AWS infrastructure scope.",
                    RationaleStatement="Most users never need Office macros.",
                    ImpactStatement="",
                    RemediationProcedure="Disable macros via Group Policy / Intune / M365 admin policies.",
                    AuditProcedure="Manual review of M365 macro policy.",
                    AdditionalInformation="ASD Essential Eight ML1 - Restrict Microsoft Office macros - clause 1. Out of AWS infrastructure scope.",
                    References="https://www.cyber.gov.au/resources-business-and-government/essential-cyber-security/essential-eight/essential-eight-maturity-model",
                )
            ],
            Checks=[],
        ),
    ],
)

@@ -0,0 +1,730 @@
"""Tests for process_universal_compliance_frameworks and --list-compliance fixes.

Validates that the pre-processing step:
- generates both CSV and OCSF outputs for universal frameworks
- always generates OCSF (no output-format gate)
- skips frameworks without outputs or table_config
- skips frameworks not in universal_frameworks
- returns the set of processed names for removal from the legacy loop
- works across different providers

Also validates that print_compliance_frameworks and print_compliance_requirements
work with universal ComplianceFramework objects (dict checks, None provider).
"""

import json
import os
from datetime import datetime, timezone
from types import SimpleNamespace

import pytest

from prowler.lib.check.check import (
    print_compliance_frameworks,
    print_compliance_requirements,
)
from prowler.lib.check.compliance_models import (
    AttributeMetadata,
    ComplianceFramework,
    OutputsConfig,
    TableConfig,
    UniversalComplianceRequirement,
)
from prowler.lib.outputs.compliance.compliance import (
    process_universal_compliance_frameworks,
)
from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
    OCSFComplianceOutput,
)
from prowler.lib.outputs.compliance.universal.universal_output import (
    UniversalComplianceOutput,
)


@pytest.fixture(autouse=True)
def _create_compliance_dir(tmp_path):
    """Ensure the compliance/ subdirectory exists before each test."""
    os.makedirs(tmp_path / "compliance", exist_ok=True)


# ── Helpers ──────────────────────────────────────────────────────────


def _make_finding(check_id, status="PASS", provider="aws"):
    """Create a mock Finding with all fields needed by both output classes."""
    finding = SimpleNamespace()
    finding.provider = provider
    finding.account_uid = "123456789012"
    finding.account_name = "test-account"
    finding.account_email = ""
    finding.account_organization_uid = "org-123"
    finding.account_organization_name = "test-org"
    finding.account_tags = {"env": "test"}
    finding.region = "us-east-1"
    finding.status = status
    finding.status_extended = f"{check_id} is {status}"
    finding.resource_uid = f"arn:aws:iam::123456789012:{check_id}"
    finding.resource_name = check_id
    finding.resource_details = "some details"
    finding.resource_metadata = {}
    finding.resource_tags = {"Name": "test"}
    finding.partition = "aws"
    finding.muted = False
    finding.check_id = check_id
    finding.uid = "test-finding-uid"
    finding.timestamp = datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
    finding.prowler_version = "5.0.0"
    finding.compliance = {"TestFW-1.0": ["1.1"]}
    finding.metadata = SimpleNamespace(
        Provider=provider,
        CheckID=check_id,
        CheckTitle=f"Title for {check_id}",
        CheckType=["test-type"],
        Description=f"Description for {check_id}",
        Severity="medium",
        ServiceName="iam",
        ResourceType="aws-iam-role",
        Risk="test-risk",
        RelatedUrl="https://example.com",
        Remediation=SimpleNamespace(
            Recommendation=SimpleNamespace(Text="Fix it", Url="https://fix.com"),
        ),
        DependsOn=[],
        RelatedTo=[],
        Categories=["test"],
        Notes="",
        AdditionalURLs=[],
    )
    return finding


def _make_universal_framework(name="TestFW", version="1.0", with_table_config=True):
    """Build a ComplianceFramework with an optional table_config."""
    reqs = [
        UniversalComplianceRequirement(
            id="1.1",
            description="Test requirement",
            attributes={"Section": "IAM"},
            checks={"aws": ["check_a"]},
        ),
    ]
    metadata = [AttributeMetadata(key="Section", type="str")]
    outputs = None
    if with_table_config:
        outputs = OutputsConfig(table_config=TableConfig(group_by="Section"))
    return ComplianceFramework(
        framework=name,
        name=f"{name} Framework",
        provider="AWS",
        version=version,
        description="Test framework",
        requirements=reqs,
        attributes_metadata=metadata,
        outputs=outputs,
    )


# ── Tests ────────────────────────────────────────────────────────────


class TestProcessUniversalComplianceFrameworks:
    """Core tests for the extracted pre-processing function."""

    def test_generates_csv_and_ocsf_outputs(self, tmp_path):
        """Both CSV and OCSF outputs are appended to generated_outputs."""
        fw = _make_universal_framework()
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == {"test_fw_1.0"}
        assert len(generated["compliance"]) == 2
        assert isinstance(generated["compliance"][0], UniversalComplianceOutput)
        assert isinstance(generated["compliance"][1], OCSFComplianceOutput)

    def test_ocsf_always_generated_no_format_gate(self, tmp_path):
        """OCSF output is generated regardless of output_formats — no gate."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )

        ocsf_outputs = [
            o for o in generated["compliance"] if isinstance(o, OCSFComplianceOutput)
        ]
        assert len(ocsf_outputs) == 1

    def test_csv_file_written(self, tmp_path):
        """The CSV file is created with the expected content."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )

        csv_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.csv"
        assert csv_path.exists()
        content = csv_path.read_text()
        assert "PROVIDER" in content
        assert "REQUIREMENTS_ATTRIBUTES_SECTION" in content

    def test_ocsf_file_written(self, tmp_path):
        """The OCSF JSON file is created with valid content."""
        fw = _make_universal_framework()
        generated = {"compliance": []}
        process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="prowler_output",
            provider="aws",
            generated_outputs=generated,
        )

        ocsf_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.ocsf.json"
        assert ocsf_path.exists()
        data = json.loads(ocsf_path.read_text())
        assert isinstance(data, list)
        assert len(data) >= 1
        assert data[0]["class_uid"] == 2003

    def test_returns_processed_names(self, tmp_path):
        """Returns the set of framework names that were processed."""
        fw = _make_universal_framework()
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0", "legacy_fw"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == {"test_fw_1.0"}
        assert "legacy_fw" not in processed


class TestSkipConditions:
    """Tests for frameworks that should NOT be processed."""

    def test_skips_framework_not_in_universal(self, tmp_path):
        """Frameworks not in the universal_frameworks dict are skipped."""
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"cis_aws_1.4"},
            universal_frameworks={},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == set()
        assert len(generated["compliance"]) == 0

    def test_skips_framework_without_outputs(self, tmp_path):
        """Frameworks with outputs=None are skipped."""
        fw = _make_universal_framework(with_table_config=False)
        # outputs is None since with_table_config=False
        assert fw.outputs is None
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == set()
        assert len(generated["compliance"]) == 0

    def test_skips_framework_with_outputs_but_no_table_config(self, tmp_path):
        """Frameworks with outputs but table_config=None are skipped."""
        fw = _make_universal_framework()
        # Manually set table_config to None while keeping outputs
        fw.outputs = OutputsConfig(table_config=None)
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == set()
        assert len(generated["compliance"]) == 0

    def test_empty_input_frameworks(self, tmp_path):
        """No processing happens when the input set is empty."""
        fw = _make_universal_framework()
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks=set(),
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == set()
        assert len(generated["compliance"]) == 0


class TestMixedFrameworks:
    """Tests with a mix of universal and legacy frameworks."""

    def test_only_universal_processed_legacy_untouched(self, tmp_path):
        """Only universal frameworks are processed; legacy names are not returned."""
        universal_fw = _make_universal_framework()
        generated = {"compliance": []}

        all_frameworks = {"test_fw_1.0", "cis_aws_1.4", "nist_800_53_aws"}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks=all_frameworks,
            universal_frameworks={"test_fw_1.0": universal_fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == {"test_fw_1.0"}
        # 2 outputs for the one universal framework (CSV + OCSF)
        assert len(generated["compliance"]) == 2

    def test_removal_from_input_set(self, tmp_path):
        """The caller can subtract the processed set from the input to get legacy-only frameworks."""
        universal_fw = _make_universal_framework()
        generated = {"compliance": []}

        input_frameworks = {"test_fw_1.0", "cis_aws_1.4", "nist_800_53_aws"}
        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks=input_frameworks,
            universal_frameworks={"test_fw_1.0": universal_fw},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        remaining = input_frameworks - processed
        assert remaining == {"cis_aws_1.4", "nist_800_53_aws"}

    def test_multiple_universal_frameworks(self, tmp_path):
        """Multiple universal frameworks each get CSV + OCSF."""
        fw1 = _make_universal_framework(name="FW1", version="1.0")
        fw2 = _make_universal_framework(name="FW2", version="2.0")
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"fw1_1.0", "fw2_2.0", "legacy"},
            universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
            finding_outputs=[_make_finding("check_a")],
            output_directory=str(tmp_path),
            output_filename="out",
            provider="aws",
            generated_outputs=generated,
        )

        assert processed == {"fw1_1.0", "fw2_2.0"}
        # 2 frameworks × 2 outputs each = 4
        assert len(generated["compliance"]) == 4
        csv_outputs = [
            o
            for o in generated["compliance"]
            if isinstance(o, UniversalComplianceOutput)
        ]
        ocsf_outputs = [
            o for o in generated["compliance"] if isinstance(o, OCSFComplianceOutput)
        ]
        assert len(csv_outputs) == 2
        assert len(ocsf_outputs) == 2


class TestProviderVariants:
    """Verify the function works for different providers."""

    @pytest.mark.parametrize(
        "provider",
        [
            "aws",
            "azure",
            "gcp",
            "kubernetes",
            "m365",
            "github",
            "oraclecloud",
            "alibabacloud",
            "nhn",
        ],
    )
    def test_all_providers_produce_outputs(self, tmp_path, provider):
        """Each provider generates CSV + OCSF when given a universal framework."""
        fw = _make_universal_framework()
        generated = {"compliance": []}

        processed = process_universal_compliance_frameworks(
            input_compliance_frameworks={"test_fw_1.0"},
            universal_frameworks={"test_fw_1.0": fw},
            finding_outputs=[_make_finding("check_a", provider=provider)],
            output_directory=str(tmp_path),
            output_filename="out",
            provider=provider,
            generated_outputs=generated,
        )

        assert processed == {"test_fw_1.0"}
        assert len(generated["compliance"]) == 2
        assert isinstance(generated["compliance"][0], UniversalComplianceOutput)
        assert isinstance(generated["compliance"][1], OCSFComplianceOutput)


class TestEmptyFindings:
    """Test behavior when there are no findings."""

    def test_still_processed_with_empty_findings(self, tmp_path):
|
||||
"""Framework is still marked as processed even with no findings."""
|
||||
fw = _make_universal_framework()
|
||||
generated = {"compliance": []}
|
||||
|
||||
processed = process_universal_compliance_frameworks(
|
||||
input_compliance_frameworks={"test_fw_1.0"},
|
||||
universal_frameworks={"test_fw_1.0": fw},
|
||||
finding_outputs=[],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="out",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
|
||||
assert processed == {"test_fw_1.0"}
|
||||
# Outputs are still appended (they'll just have empty data)
|
||||
assert len(generated["compliance"]) == 2
|
||||
|
||||
|
||||
class TestFilePaths:
|
||||
"""Verify correct file path construction."""
|
||||
|
||||
def test_csv_path_format(self, tmp_path):
|
||||
"""CSV output has the correct file path."""
|
||||
fw = _make_universal_framework()
|
||||
generated = {"compliance": []}
|
||||
|
||||
process_universal_compliance_frameworks(
|
||||
input_compliance_frameworks={"csa_ccm_4.0"},
|
||||
universal_frameworks={"csa_ccm_4.0": fw},
|
||||
finding_outputs=[_make_finding("check_a")],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="prowler_report",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
|
||||
csv_output = generated["compliance"][0]
|
||||
assert csv_output.file_path == (
|
||||
f"{tmp_path}/compliance/prowler_report_csa_ccm_4.0.csv"
|
||||
)
|
||||
|
||||
def test_ocsf_path_format(self, tmp_path):
|
||||
"""OCSF output has the correct file path."""
|
||||
fw = _make_universal_framework()
|
||||
generated = {"compliance": []}
|
||||
|
||||
process_universal_compliance_frameworks(
|
||||
input_compliance_frameworks={"csa_ccm_4.0"},
|
||||
universal_frameworks={"csa_ccm_4.0": fw},
|
||||
finding_outputs=[_make_finding("check_a")],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="prowler_report",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
|
||||
ocsf_output = generated["compliance"][1]
|
||||
assert ocsf_output.file_path == (
|
||||
f"{tmp_path}/compliance/prowler_report_csa_ccm_4.0.ocsf.json"
|
||||
)
|
||||
|
||||
|
||||
# ── Tests for --list-compliance fix ──────────────────────────────────
|
||||
|
||||
|
||||
def _make_legacy_compliance():
|
||||
"""Create a mock legacy Compliance-like object with the expected attributes."""
|
||||
return SimpleNamespace(
|
||||
Framework="CIS",
|
||||
Provider="AWS",
|
||||
Version="1.4",
|
||||
Requirements=[
|
||||
SimpleNamespace(
|
||||
Id="2.1.3",
|
||||
Description="Ensure MFA Delete is enabled",
|
||||
Checks=["s3_bucket_mfa_delete"],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestPrintComplianceFrameworks:
|
||||
"""Tests for print_compliance_frameworks with universal frameworks."""
|
||||
|
||||
def test_includes_universal_frameworks(self, capsys):
|
||||
"""Universal frameworks appear in the listing."""
|
||||
legacy = {"cis_1.4_aws": _make_legacy_compliance()}
|
||||
universal = {"csa_ccm_4.0": _make_universal_framework()}
|
||||
merged = {**legacy, **universal}
|
||||
|
||||
print_compliance_frameworks(merged)
|
||||
captured = capsys.readouterr().out
|
||||
|
||||
assert "cis_1.4_aws" in captured
|
||||
assert "csa_ccm_4.0" in captured
|
||||
|
||||
def test_count_includes_both(self, capsys):
|
||||
"""Framework count includes both legacy and universal."""
|
||||
legacy = {"cis_1.4_aws": _make_legacy_compliance()}
|
||||
universal = {"csa_ccm_4.0": _make_universal_framework()}
|
||||
merged = {**legacy, **universal}
|
||||
|
||||
print_compliance_frameworks(merged)
|
||||
captured = capsys.readouterr().out
|
||||
|
||||
assert "2" in captured
|
||||
|
||||
def test_universal_only(self, capsys):
|
||||
"""Works when only universal frameworks are present."""
|
||||
universal = {"csa_ccm_4.0": _make_universal_framework()}
|
||||
|
||||
print_compliance_frameworks(universal)
|
||||
captured = capsys.readouterr().out
|
||||
|
||||
assert "csa_ccm_4.0" in captured
|
||||
assert "1" in captured
|
||||
|
||||
|
||||
class TestPrintComplianceRequirements:
|
||||
"""Tests for print_compliance_requirements with universal frameworks."""
|
||||
|
||||
def test_list_checks_universal_framework(self, capsys):
|
||||
"""Requirements with dict checks are printed correctly."""
|
||||
fw = _make_universal_framework()
|
||||
all_fw = {"test_fw_1.0": fw}
|
||||
|
||||
print_compliance_requirements(all_fw, ["test_fw_1.0"])
|
||||
captured = capsys.readouterr().out
|
||||
|
||||
assert "1.1" in captured
|
||||
assert "check_a" in captured
|
||||
|
||||
def test_dict_checks_universal_framework(self, capsys):
|
||||
"""Requirements with dict checks show provider-prefixed checks."""
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="A&A-01",
|
||||
description="Audit & Assurance",
|
||||
attributes={"Section": "A&A"},
|
||||
checks={"aws": ["check_a", "check_b"], "azure": ["check_c"]},
|
||||
),
|
||||
]
|
||||
fw = ComplianceFramework(
|
||||
framework="CSA_CCM",
|
||||
name="CSA CCM 4.0",
|
||||
version="4.0",
|
||||
description="Cloud Controls Matrix",
|
||||
requirements=reqs,
|
||||
)
|
||||
all_fw = {"csa_ccm_4.0": fw}
|
||||
|
||||
print_compliance_requirements(all_fw, ["csa_ccm_4.0"])
|
||||
captured = capsys.readouterr().out
|
||||
|
||||
assert "A&A-01" in captured
|
||||
assert "[aws] check_a" in captured
|
||||
assert "[aws] check_b" in captured
|
||||
assert "[azure] check_c" in captured
|
||||
|
||||
def test_none_provider_shows_multi_provider(self, capsys):
|
||||
"""Frameworks with provider=None show 'Multi-provider'."""
|
||||
fw = ComplianceFramework(
|
||||
framework="CSA_CCM",
|
||||
name="CSA CCM 4.0",
|
||||
version="4.0",
|
||||
description="Cloud Controls Matrix",
|
||||
requirements=[
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={},
|
||||
checks={"aws": ["check_a"]},
|
||||
),
|
||||
],
|
||||
)
|
||||
all_fw = {"csa_ccm_4.0": fw}
|
||||
|
||||
print_compliance_requirements(all_fw, ["csa_ccm_4.0"])
|
||||
captured = capsys.readouterr().out
|
||||
|
||||
assert "Multi-provider" in captured
|
||||
|
||||
|
||||
# ── Idempotency tests ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIdempotency:
|
||||
"""The function must be safe to invoke multiple times for the same
|
||||
framework. Repeated calls must reuse writers tracked in
|
||||
``generated_outputs["compliance"]`` instead of recreating them.
|
||||
|
||||
This guards against:
|
||||
- duplicate writer entries in generated_outputs (regular pipeline
|
||||
treats one writer per framework)
|
||||
- the OCSF append-bug where a second writer would emit
|
||||
``[...]<new>...]`` and break the JSON array.
|
||||
"""
|
||||
|
||||
def test_second_call_does_not_duplicate_writers(self, tmp_path):
|
||||
fw = _make_universal_framework()
|
||||
generated = {"compliance": []}
|
||||
kwargs = dict(
|
||||
input_compliance_frameworks={"test_fw_1.0"},
|
||||
universal_frameworks={"test_fw_1.0": fw},
|
||||
finding_outputs=[_make_finding("check_a")],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="prowler_output",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
|
||||
first = process_universal_compliance_frameworks(**kwargs)
|
||||
first_count = len(generated["compliance"])
|
||||
second = process_universal_compliance_frameworks(**kwargs)
|
||||
second_count = len(generated["compliance"])
|
||||
|
||||
assert first == {"test_fw_1.0"}
|
||||
assert second == {"test_fw_1.0"} # still reported as processed
|
||||
assert first_count == 2 # CSV + OCSF
|
||||
assert second_count == 2 # NO duplication
|
||||
|
||||
def test_second_call_keeps_ocsf_json_valid(self, tmp_path):
|
||||
"""End-to-end: after two calls the OCSF JSON file must still be
|
||||
a single, valid JSON array — not the broken ``[...]...]`` form."""
|
||||
fw = _make_universal_framework()
|
||||
generated = {"compliance": []}
|
||||
kwargs = dict(
|
||||
input_compliance_frameworks={"test_fw_1.0"},
|
||||
universal_frameworks={"test_fw_1.0": fw},
|
||||
finding_outputs=[_make_finding("check_a")],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="prowler_output",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
|
||||
process_universal_compliance_frameworks(**kwargs)
|
||||
process_universal_compliance_frameworks(**kwargs)
|
||||
|
||||
ocsf_path = tmp_path / "compliance" / "prowler_output_test_fw_1.0.ocsf.json"
|
||||
data = json.loads(ocsf_path.read_text()) # Will raise on invalid JSON
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
def test_reuses_existing_writer_object(self, tmp_path):
|
||||
"""The CSV/OCSF writer instances appended on first call must be
|
||||
the SAME objects after a second call — not fresh ones."""
|
||||
fw = _make_universal_framework()
|
||||
generated = {"compliance": []}
|
||||
kwargs = dict(
|
||||
input_compliance_frameworks={"test_fw_1.0"},
|
||||
universal_frameworks={"test_fw_1.0": fw},
|
||||
finding_outputs=[_make_finding("check_a")],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="prowler_output",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
|
||||
process_universal_compliance_frameworks(**kwargs)
|
||||
first_writers = list(generated["compliance"])
|
||||
process_universal_compliance_frameworks(**kwargs)
|
||||
second_writers = list(generated["compliance"])
|
||||
|
||||
# Same identity, same length — reused, not recreated.
|
||||
assert len(first_writers) == len(second_writers)
|
||||
for a, b in zip(first_writers, second_writers):
|
||||
assert a is b
|
||||
|
||||
def test_idempotency_across_mixed_frameworks(self, tmp_path):
|
||||
"""When the second call adds a new framework, the new one is
|
||||
created while existing ones are NOT recreated."""
|
||||
fw1 = _make_universal_framework(name="FW1", version="1.0")
|
||||
fw2 = _make_universal_framework(name="FW2", version="2.0")
|
||||
generated = {"compliance": []}
|
||||
|
||||
# First call: only FW1
|
||||
process_universal_compliance_frameworks(
|
||||
input_compliance_frameworks={"fw1_1.0"},
|
||||
universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
|
||||
finding_outputs=[_make_finding("check_a")],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="out",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
first_writers = list(generated["compliance"])
|
||||
assert len(first_writers) == 2
|
||||
|
||||
# Second call: includes both. FW1 must be reused, FW2 created fresh.
|
||||
process_universal_compliance_frameworks(
|
||||
input_compliance_frameworks={"fw1_1.0", "fw2_2.0"},
|
||||
universal_frameworks={"fw1_1.0": fw1, "fw2_2.0": fw2},
|
||||
finding_outputs=[_make_finding("check_a")],
|
||||
output_directory=str(tmp_path),
|
||||
output_filename="out",
|
||||
provider="aws",
|
||||
generated_outputs=generated,
|
||||
)
|
||||
second_writers = list(generated["compliance"])
|
||||
assert len(second_writers) == 4 # 2 (FW1 reused) + 2 new (FW2)
|
||||
# FW1 writer instances unchanged
|
||||
assert second_writers[0] is first_writers[0]
|
||||
assert second_writers[1] is first_writers[1]
|
||||
@@ -2,6 +2,7 @@ import json
from datetime import datetime, timezone
from types import SimpleNamespace

from py_ocsf_models.events.base_event import StatusID as EventStatusID
from py_ocsf_models.events.findings.compliance_finding import ComplianceFinding
from py_ocsf_models.events.findings.compliance_finding_type_id import (
    ComplianceFindingTypeID,
@@ -18,6 +19,7 @@ from prowler.lib.check.compliance_models import (
)
from prowler.lib.outputs.compliance.universal.ocsf_compliance import (
    OCSFComplianceOutput,
    _sanitize_resource_data,
)

@@ -473,3 +475,159 @@ class TestOCSFComplianceOutput:
        cf = output.data[0]
        assert cf.unmapped["requirement_attributes"]["section"] == "Logging"
        assert "internal_note" not in cf.unmapped["requirement_attributes"]


class TestSanitizeResourceData:
    """Unit tests for the _sanitize_resource_data helper.

    Service resources may carry non-JSON-serializable objects (e.g. raw
    Pydantic models such as ``Trail`` or ``LifecyclePolicy``). The helper
    must convert them so the resulting ComplianceFinding can be serialized.
    """

    def test_dict_passthrough(self):
        result = _sanitize_resource_data("details", {"a": 1, "b": "two"})
        assert result == {"details": "details", "metadata": {"a": 1, "b": "two"}}

    def test_none_metadata(self):
        result = _sanitize_resource_data("details", None)
        assert result == {"details": "details", "metadata": None}

    def test_pydantic_v2_model_dump(self):
        class FakeV2Model:
            def model_dump(self):
                return {"name": "trail-1", "region": "us-east-1"}

        result = _sanitize_resource_data("d", {"trail": FakeV2Model()})
        assert result["metadata"]["trail"] == {
            "name": "trail-1",
            "region": "us-east-1",
        }

    def test_pydantic_v1_dict(self):
        class FakeV1Model:
            def dict(self):
                return {"name": "policy-1", "schedule": "daily"}

        result = _sanitize_resource_data("d", {"policy": FakeV1Model()})
        assert result["metadata"]["policy"] == {
            "name": "policy-1",
            "schedule": "daily",
        }

    def test_nested_pydantic_in_list(self):
        class FakeModel:
            def model_dump(self):
                return {"id": "x"}

        result = _sanitize_resource_data("d", {"items": [FakeModel(), FakeModel()]})
        assert result["metadata"]["items"] == [{"id": "x"}, {"id": "x"}]

    def test_nested_dict_recursion(self):
        class FakeInner:
            def model_dump(self):
                return {"k": "v"}

        result = _sanitize_resource_data(
            "d", {"outer": {"inner": FakeInner(), "x": [1, 2]}}
        )
        assert result["metadata"]["outer"]["inner"] == {"k": "v"}
        assert result["metadata"]["outer"]["x"] == [1, 2]

    def test_tuple_to_list(self):
        result = _sanitize_resource_data("d", {"t": (1, 2, "three")})
        assert result["metadata"]["t"] == [1, 2, "three"]

    def test_non_string_dict_keys_coerced(self):
        result = _sanitize_resource_data("d", {1: "a", 2: "b"})
        assert result["metadata"] == {"1": "a", "2": "b"}

    def test_unknown_object_falls_back_to_str(self):
        class Opaque:
            def __str__(self):
                return "opaque-repr"

        result = _sanitize_resource_data("d", {"thing": Opaque()})
        assert result["metadata"]["thing"] == "opaque-repr"

    def test_circular_reference_falls_back_to_empty(self):
        a = {}
        a["self"] = a
        # json.dumps raises ValueError on recursion → fallback to empty metadata
        result = _sanitize_resource_data("d", a)
        assert result == {"details": "d", "metadata": {}}

    def test_serializes_via_full_finding_pipeline(self):
        """End-to-end: a finding with a non-serializable resource_metadata
        produces a JSON-serializable ComplianceFinding."""

        class TrailLike:
            def __init__(self):
                self.name = "trail-A"
                self.kms_key_id = "arn:aws:kms:..."

            def model_dump(self):
                return {"name": self.name, "kms_key_id": self.kms_key_id}

        finding = _make_finding("check_a")
        finding.resource_metadata = {"trail": TrailLike()}
        req = _simple_requirement()
        fw = _make_framework([req])

        output = OCSFComplianceOutput(findings=[finding], framework=fw, provider="aws")

        # Serialize the resulting ComplianceFinding — must NOT raise
        cf = output.data[0]
        if hasattr(cf, "model_dump_json"):
            json_output = cf.model_dump_json(exclude_none=True)
        else:
            json_output = cf.json(exclude_none=True)
        payload = json.loads(json_output)

        # Confirm the trail object made it through as a plain dict
        assert payload["resources"][0]["data"]["metadata"]["trail"]["name"] == "trail-A"


class TestEventStatusInline:
    """Tests for the inlined event_status logic that replaced
    OCSF.get_finding_status_id() to break the cyclic import."""

    def test_unmuted_finding_status_new(self):
        finding = _make_finding("check_a")
        finding.muted = False
        req = _simple_requirement()
        fw = _make_framework([req])

        output = OCSFComplianceOutput(findings=[finding], framework=fw, provider="aws")
        cf = output.data[0]

        assert cf.status_id == EventStatusID.New.value
        assert cf.status == EventStatusID.New.name

    def test_muted_finding_status_suppressed(self):
        finding = _make_finding("check_a")
        finding.muted = True
        req = _simple_requirement()
        fw = _make_framework([req])

        output = OCSFComplianceOutput(findings=[finding], framework=fw, provider="aws")
        cf = output.data[0]

        assert cf.status_id == EventStatusID.Suppressed.value
        assert cf.status == EventStatusID.Suppressed.name


class TestNoTopLevelOCSFImport:
    """Regression test: the top-level OCSF/Finding imports were removed
    to break the CodeQL cyclic-import warnings. Ensure they stay out of
    the runtime namespace of the module (TYPE_CHECKING block only)."""

    def test_finding_not_in_runtime_namespace(self):
        import prowler.lib.outputs.compliance.universal.ocsf_compliance as mod

        assert "Finding" not in dir(mod)

    def test_ocsf_class_not_imported(self):
        import prowler.lib.outputs.compliance.universal.ocsf_compliance as mod

        assert "OCSF" not in dir(mod)

@@ -0,0 +1,568 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from prowler.lib.check.compliance_models import (
|
||||
AttributeMetadata,
|
||||
ComplianceFramework,
|
||||
OutputFormats,
|
||||
OutputsConfig,
|
||||
TableConfig,
|
||||
UniversalComplianceRequirement,
|
||||
)
|
||||
from prowler.lib.outputs.compliance.universal.universal_output import (
|
||||
UniversalComplianceOutput,
|
||||
)
|
||||
|
||||
|
||||
def _make_finding(check_id, status="PASS", compliance_map=None):
|
||||
"""Create a mock Finding for output tests."""
|
||||
finding = SimpleNamespace()
|
||||
finding.provider = "aws"
|
||||
finding.account_uid = "123456789012"
|
||||
finding.account_name = "test-account"
|
||||
finding.region = "us-east-1"
|
||||
finding.status = status
|
||||
finding.status_extended = f"{check_id} is {status}"
|
||||
finding.resource_uid = f"arn:aws:iam::123456789012:{check_id}"
|
||||
finding.resource_name = check_id
|
||||
finding.muted = False
|
||||
finding.check_id = check_id
|
||||
finding.metadata = SimpleNamespace(
|
||||
Provider="aws",
|
||||
CheckID=check_id,
|
||||
Severity="medium",
|
||||
)
|
||||
finding.compliance = compliance_map or {}
|
||||
return finding
|
||||
|
||||
|
||||
def _make_framework(requirements, attrs_metadata=None, table_config=None):
|
||||
return ComplianceFramework(
|
||||
framework="TestFW",
|
||||
name="Test Framework",
|
||||
provider="AWS",
|
||||
version="1.0",
|
||||
description="Test framework",
|
||||
requirements=requirements,
|
||||
attributes_metadata=attrs_metadata,
|
||||
outputs=OutputsConfig(table_config=table_config) if table_config else None,
|
||||
)
|
||||
|
||||
|
||||
class TestDynamicCSVColumns:
|
||||
def test_columns_match_metadata(self, tmp_path):
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={"Section": "IAM", "SubSection": "Auth"},
|
||||
checks={"aws": ["check_a"]},
|
||||
),
|
||||
]
|
||||
metadata = [
|
||||
AttributeMetadata(key="Section", type="str"),
|
||||
AttributeMetadata(key="SubSection", type="str"),
|
||||
]
|
||||
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
|
||||
|
||||
findings = [
|
||||
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
|
||||
]
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=findings,
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
)
|
||||
|
||||
assert len(output.data) == 1
|
||||
row_dict = output.data[0].dict()
|
||||
assert "Requirements_Attributes_Section" in row_dict
|
||||
assert "Requirements_Attributes_SubSection" in row_dict
|
||||
assert row_dict["Requirements_Attributes_Section"] == "IAM"
|
||||
assert row_dict["Requirements_Attributes_SubSection"] == "Auth"
|
||||
|
||||
|
||||
class TestManualRequirements:
|
||||
def test_manual_status(self, tmp_path):
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={"Section": "IAM"},
|
||||
checks={"aws": ["check_a"]},
|
||||
),
|
||||
UniversalComplianceRequirement(
|
||||
id="manual-1",
|
||||
description="manual check",
|
||||
attributes={"Section": "Governance"},
|
||||
checks={},
|
||||
),
|
||||
]
|
||||
metadata = [
|
||||
AttributeMetadata(key="Section", type="str"),
|
||||
]
|
||||
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
|
||||
|
||||
findings = [
|
||||
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
|
||||
]
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=findings,
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
)
|
||||
|
||||
# Should have 1 real finding + 1 manual
|
||||
assert len(output.data) == 2
|
||||
manual_rows = [r for r in output.data if r.dict()["Status"] == "MANUAL"]
|
||||
assert len(manual_rows) == 1
|
||||
assert manual_rows[0].dict()["Requirements_Id"] == "manual-1"
|
||||
assert manual_rows[0].dict()["ResourceId"] == "manual_check"
|
||||
|
||||
|
||||
class TestMITREExtraColumns:
|
||||
def test_mitre_columns_present(self, tmp_path):
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="T1190",
|
||||
description="Exploit",
|
||||
attributes={},
|
||||
checks={"aws": ["check_a"]},
|
||||
tactics=["Initial Access"],
|
||||
sub_techniques=[],
|
||||
platforms=["IaaS"],
|
||||
technique_url="https://attack.mitre.org/techniques/T1190/",
|
||||
),
|
||||
]
|
||||
fw = _make_framework(reqs, None, TableConfig(group_by="_Tactics"))
|
||||
|
||||
findings = [
|
||||
_make_finding("check_a", "PASS", {"TestFW-1.0": ["T1190"]}),
|
||||
]
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=findings,
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
)
|
||||
|
||||
assert len(output.data) == 1
|
||||
row_dict = output.data[0].dict()
|
||||
assert "Requirements_Tactics" in row_dict
|
||||
assert row_dict["Requirements_Tactics"] == "Initial Access"
|
||||
assert "Requirements_TechniqueURL" in row_dict
|
||||
|
||||
|
||||
class TestCSVFileWrite:
|
||||
def test_batch_write(self, tmp_path):
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={"Section": "IAM"},
|
||||
checks={"aws": ["check_a"]},
|
||||
),
|
||||
]
|
||||
metadata = [
|
||||
AttributeMetadata(key="Section", type="str"),
|
||||
]
|
||||
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
|
||||
|
||||
findings = [
|
||||
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
|
||||
]
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=findings,
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
)
|
||||
output.batch_write_data_to_file()
|
||||
|
||||
# Verify file was created and has content
|
||||
with open(filepath, "r") as f:
|
||||
content = f.read()
|
||||
assert "PROVIDER" in content # Headers are uppercase
|
||||
assert "REQUIREMENTS_ATTRIBUTES_SECTION" in content
|
||||
assert "IAM" in content
|
||||
|
||||
|
||||
class TestNoFindings:
|
||||
def test_empty_findings_no_data(self, tmp_path):
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={"Section": "IAM"},
|
||||
checks={"aws": ["check_a"]},
|
||||
),
|
||||
]
|
||||
fw = _make_framework(reqs, None, TableConfig(group_by="Section"))
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=[],
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
)
|
||||
assert len(output.data) == 0
|
||||
|
||||
|
||||
class TestMultiProviderOutput:
|
||||
def test_dict_checks_filtered_by_provider(self, tmp_path):
|
||||
"""Only checks for the given provider appear in CSV output."""
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={"Section": "IAM"},
|
||||
checks={"aws": ["check_a"], "azure": ["check_b"]},
|
||||
),
|
||||
]
|
||||
metadata = [
|
||||
AttributeMetadata(key="Section", type="str"),
|
||||
]
|
||||
fw = ComplianceFramework(
|
||||
framework="MultiCloud",
|
||||
name="Multi",
|
||||
version="1.0",
|
||||
description="Test multi-provider",
|
||||
requirements=reqs,
|
||||
attributes_metadata=metadata,
|
||||
outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
|
||||
)
|
||||
|
||||
findings = [
|
||||
_make_finding("check_a", "PASS", {"MultiCloud-1.0": ["1.1"]}),
|
||||
_make_finding("check_b", "FAIL", {"MultiCloud-1.0": ["1.1"]}),
|
||||
]
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=findings,
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
provider="aws",
|
||||
)
|
||||
|
||||
# Only check_a should match (it's the AWS check)
|
||||
assert len(output.data) == 1
|
||||
row_dict = output.data[0].dict()
|
||||
assert row_dict["Requirements_Attributes_Section"] == "IAM"
|
||||
|
||||
def test_no_provider_includes_all(self, tmp_path):
|
||||
"""Without provider filter, all checks from all providers are included."""
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={"Section": "IAM"},
|
||||
checks={"aws": ["check_a"], "azure": ["check_b"]},
|
||||
),
|
||||
]
|
||||
metadata = [
|
||||
AttributeMetadata(key="Section", type="str"),
|
||||
]
|
||||
fw = ComplianceFramework(
|
||||
framework="MultiCloud",
|
||||
name="Multi",
|
||||
version="1.0",
|
||||
description="Test multi-provider",
|
||||
requirements=reqs,
|
||||
attributes_metadata=metadata,
|
||||
outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
|
||||
)
|
||||
|
||||
findings = [
|
||||
_make_finding("check_a", "PASS", {"MultiCloud-1.0": ["1.1"]}),
|
||||
_make_finding("check_b", "FAIL", {"MultiCloud-1.0": ["1.1"]}),
|
||||
]
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=findings,
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
)
|
||||
|
||||
# Both checks should be included without provider filter
|
||||
assert len(output.data) == 2
|
||||
|
||||
def test_empty_dict_checks_is_manual(self, tmp_path):
|
||||
"""Requirement with empty dict checks is treated as manual."""
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="manual-1",
|
||||
description="manual check",
|
||||
attributes={"Section": "Governance"},
|
||||
checks={},
|
||||
),
|
||||
]
|
||||
metadata = [
|
||||
AttributeMetadata(key="Section", type="str"),
|
||||
]
|
||||
fw = ComplianceFramework(
|
||||
framework="MultiCloud",
|
||||
name="Multi",
|
||||
version="1.0",
|
||||
description="Test",
|
||||
requirements=reqs,
|
||||
attributes_metadata=metadata,
|
||||
outputs=OutputsConfig(table_config=TableConfig(group_by="Section")),
|
||||
)
|
||||
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=[_make_finding("other_check", "PASS", {})],
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
provider="aws",
|
||||
)
|
||||
|
||||
manual_rows = [r for r in output.data if r.dict()["Status"] == "MANUAL"]
|
||||
assert len(manual_rows) == 1
|
||||
assert manual_rows[0].dict()["Requirements_Id"] == "manual-1"
|
||||
|
||||
|
||||
class TestCSVExclude:
|
||||
def test_csv_false_excludes_column(self, tmp_path):
|
||||
reqs = [
|
||||
UniversalComplianceRequirement(
|
||||
id="1.1",
|
||||
description="test",
|
||||
attributes={"Section": "IAM", "Internal": "hidden"},
|
||||
checks={"aws": ["check_a"]},
|
||||
),
|
||||
]
|
||||
metadata = [
|
||||
AttributeMetadata(
|
||||
key="Section", type="str", output_formats=OutputFormats(csv=True)
|
||||
),
|
||||
AttributeMetadata(
|
||||
key="Internal", type="str", output_formats=OutputFormats(csv=False)
|
||||
),
|
||||
]
|
||||
fw = _make_framework(reqs, metadata, TableConfig(group_by="Section"))
|
||||
|
||||
findings = [
|
||||
_make_finding("check_a", "PASS", {"TestFW-1.0": ["1.1"]}),
|
||||
]
|
||||
filepath = str(tmp_path / "test.csv")
|
||||
|
||||
output = UniversalComplianceOutput(
|
||||
findings=findings,
|
||||
framework=fw,
|
||||
file_path=filepath,
|
||||
)
|
||||
|
||||
row_dict = output.data[0].dict()
|
||||
assert "Requirements_Attributes_Section" in row_dict
|
||||
assert "Requirements_Attributes_Internal" not in row_dict
|
||||
|
||||
|
||||
def _make_provider_finding(provider, check_id="check_a", status="PASS"):
    """Create a mock Finding with a specific provider."""
    finding = _make_finding(check_id, status, {"TestFW-1.0": ["1.1"]})
    finding.provider = provider
    return finding


def _simple_framework():
    all_providers = [
        "aws",
        "azure",
        "gcp",
        "kubernetes",
        "m365",
        "github",
        "oraclecloud",
        "alibabacloud",
        "nhn",
        "unknown",
    ]
    reqs = [
        UniversalComplianceRequirement(
            id="1.1",
            description="test",
            attributes={"Section": "IAM"},
            checks={p: ["check_a"] for p in all_providers},
        ),
    ]
    metadata = [
        AttributeMetadata(key="Section", type="str"),
    ]
    return _make_framework(reqs, metadata, TableConfig(group_by="Section"))

class TestProviderHeaders:
    def test_aws_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("aws")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="aws",
        )
        row_dict = output.data[0].dict()
        assert "AccountId" in row_dict
        assert "Region" in row_dict
        assert row_dict["AccountId"] == "123456789012"
        assert row_dict["Region"] == "us-east-1"

    def test_azure_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("azure")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="azure",
        )
        row_dict = output.data[0].dict()
        assert "SubscriptionId" in row_dict
        assert "Location" in row_dict
        assert row_dict["SubscriptionId"] == "123456789012"
        assert row_dict["Location"] == "us-east-1"

    def test_gcp_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("gcp")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="gcp",
        )
        row_dict = output.data[0].dict()
        assert "ProjectId" in row_dict
        assert "Location" in row_dict
        assert row_dict["ProjectId"] == "123456789012"

    def test_kubernetes_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("kubernetes")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="kubernetes",
        )
        row_dict = output.data[0].dict()
        assert "Context" in row_dict
        assert "Namespace" in row_dict
        # Kubernetes Context maps to account_name
        assert row_dict["Context"] == "test-account"
        assert row_dict["Namespace"] == "us-east-1"

    def test_github_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("github")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="github",
        )
        row_dict = output.data[0].dict()
        assert "Account_Name" in row_dict
        assert "Account_Id" in row_dict
        # GitHub: Account_Name (pos 3) from account_name, Account_Id (pos 4) from account_uid
        assert row_dict["Account_Name"] == "test-account"
        assert row_dict["Account_Id"] == "123456789012"
        # Verify column order matches legacy (Account_Name before Account_Id)
        keys = list(row_dict.keys())
        assert keys.index("Account_Name") < keys.index("Account_Id")

    def test_unknown_provider_defaults(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("unknown")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
            provider="unknown",
        )
        row_dict = output.data[0].dict()
        assert "AccountId" in row_dict
        assert "Region" in row_dict

    def test_none_provider_defaults(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("aws")]
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=str(tmp_path / "test.csv"),
        )
        row_dict = output.data[0].dict()
        assert "AccountId" in row_dict
        assert "Region" in row_dict

    def test_csv_write_azure_headers(self, tmp_path):
        fw = _simple_framework()
        findings = [_make_provider_finding("azure")]
        filepath = str(tmp_path / "test.csv")
        output = UniversalComplianceOutput(
            findings=findings,
            framework=fw,
            file_path=filepath,
            provider="azure",
        )
        output.batch_write_data_to_file()

        with open(filepath, "r") as f:
            content = f.read()
        assert "SUBSCRIPTIONID" in content
        assert "LOCATION" in content
        # Should NOT have the default AccountId/Region headers
        assert "ACCOUNTID" not in content

    def test_column_order_matches_legacy(self, tmp_path):
        """Verify that the base column order matches the legacy per-provider models.

        Legacy models all define: Provider, Description, <col3>, <col4>, AssessmentDate, ...
        The universal output must preserve this exact order for backward compatibility.
        """
        # Expected column order per provider (positions 3 and 4 after Provider, Description)
        legacy_order = {
            "aws": ("AccountId", "Region"),
            "azure": ("SubscriptionId", "Location"),
            "gcp": ("ProjectId", "Location"),
            "kubernetes": ("Context", "Namespace"),
            "m365": ("TenantId", "Location"),
            "github": ("Account_Name", "Account_Id"),
            "oraclecloud": ("TenancyId", "Region"),
            "alibabacloud": ("AccountId", "Region"),
            "nhn": ("AccountId", "Region"),
        }

        for provider_name, (expected_col3, expected_col4) in legacy_order.items():
            fw = _simple_framework()
            findings = [_make_provider_finding(provider_name)]
            output = UniversalComplianceOutput(
                findings=findings,
                framework=fw,
                file_path=str(tmp_path / f"test_{provider_name}.csv"),
                provider=provider_name,
            )
            keys = list(output.data[0].dict().keys())
            assert keys[0] == "Provider", f"{provider_name}: col 1 should be Provider"
            assert (
                keys[1] == "Description"
            ), f"{provider_name}: col 2 should be Description"
            assert (
                keys[2] == expected_col3
            ), f"{provider_name}: col 3 should be {expected_col3}, got {keys[2]}"
            assert (
                keys[3] == expected_col4
            ), f"{provider_name}: col 4 should be {expected_col4}, got {keys[3]}"
            assert (
                keys[4] == "AssessmentDate"
            ), f"{provider_name}: col 5 should be AssessmentDate"
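The per-provider header tests above all follow the same table: columns 3 and 4 vary by provider and fall back to `AccountId`/`Region` for unknown or missing providers. A compact sketch of such a lookup (hypothetical helper, mirroring the expectations asserted in `test_column_order_matches_legacy`, not the actual output model):

```python
# Hypothetical provider -> (col3, col4) header table; values mirror the
# expectations in the tests above, but the helper name is illustrative.
PROVIDER_HEADERS = {
    "aws": ("AccountId", "Region"),
    "azure": ("SubscriptionId", "Location"),
    "gcp": ("ProjectId", "Location"),
    "kubernetes": ("Context", "Namespace"),
    "m365": ("TenantId", "Location"),
    "github": ("Account_Name", "Account_Id"),
    "oraclecloud": ("TenancyId", "Region"),
    "alibabacloud": ("AccountId", "Region"),
    "nhn": ("AccountId", "Region"),
}


def header_columns(provider):
    # Unknown providers (or provider=None) fall back to AWS-style headers.
    col3, col4 = PROVIDER_HEADERS.get(provider, ("AccountId", "Region"))
    return ["Provider", "Description", col3, col4, "AssessmentDate"]
```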
@@ -21,6 +21,7 @@ from prowler.providers.aws.config import (
     AWS_STS_GLOBAL_ENDPOINT_REGION,
     BOTO3_USER_AGENT_EXTRA,
     ROLE_SESSION_NAME,
+    get_default_session_config,
 )
 from prowler.providers.aws.exceptions.exceptions import (
     AWSArgumentTypeValidationError,
@@ -2242,6 +2243,12 @@ aws:
         assert session_config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
         assert session_config.retries == {"max_attempts": 10, "mode": "standard"}
 
+    def test_get_default_session_config(self):
+        config = get_default_session_config()
+
+        assert config.user_agent_extra == BOTO3_USER_AGENT_EXTRA
+        assert config.retries == {"max_attempts": 3, "mode": "standard"}
+
     @mock_aws
     @patch(
         "prowler.lib.check.utils.recover_checks_from_provider",
@@ -4,6 +4,8 @@ import boto3
 from botocore.exceptions import ClientError
 from moto import mock_aws
 
+from prowler.providers.aws.aws_provider import AwsProvider
+from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
 from prowler.providers.aws.lib.organizations.organizations import (
     _get_ou_metadata,
     get_organizations_metadata,

@@ -222,6 +224,20 @@ class Test_AWS_Organizations:
         assert tags == {}
         assert ou_metadata == {}
 
+    def test_get_organizations_metadata_uses_user_agent_extra(self):
+        real_session = boto3.Session()
+        real_session._session.set_default_client_config(
+            AwsProvider.set_session_config(None)
+        )
+        wrapper = MagicMock(wraps=real_session)
+
+        get_organizations_metadata("123456789012", wrapper)
+
+        wrapper.client.assert_called_once()
+        default_config = real_session._session.get_default_client_config()
+        assert default_config is not None
+        assert BOTO3_USER_AGENT_EXTRA in default_config.user_agent_extra
+
     def test_parse_organizations_metadata_with_empty_ou_metadata(self):
         tags = {"Tags": []}
         metadata = {
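The new organizations test wraps a real session in `MagicMock(wraps=...)` so that the call is both executed and recorded. A minimal self-contained sketch of that pattern (the `Session` class here is a stand-in for illustration, not boto3):

```python
from unittest.mock import MagicMock


class Session:
    """Stand-in for a real session object (hypothetical, for illustration only)."""

    def client(self, service):
        return f"{service}-client"


# wraps= forwards every call to the real object while still recording it,
# so assertions like assert_called_once_with() work on the wrapper.
wrapper = MagicMock(wraps=Session())
result = wrapper.client("organizations")
# result is "organizations-client" and the call is recorded on wrapper.client
```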
@@ -1,5 +1,6 @@
 from mock import patch
 
+from prowler.providers.aws.config import BOTO3_USER_AGENT_EXTRA
 from prowler.providers.aws.lib.service.service import AWSService
 from tests.providers.aws.utils import (
     AWS_ACCOUNT_ARN,

@@ -189,6 +190,15 @@ class TestAWSService:
             == f"arn:{service.audited_partition}:{service_name}::{AWS_ACCOUNT_NUMBER}:bucket/unknown"
         )
 
+    def test_AWSService_clients_carry_user_agent_extra(self):
+        provider = set_mocked_aws_provider()
+
+        service = AWSService("s3", provider)
+        ad_hoc_client = service.session.client("ec2", AWS_REGION_US_EAST_1)
+
+        assert BOTO3_USER_AGENT_EXTRA in service.client._client_config.user_agent_extra
+        assert BOTO3_USER_AGENT_EXTRA in ad_hoc_client._client_config.user_agent_extra
+
     def test_AWSService_get_unknown_arn_resource_type_set_region(self):
         service_name = "s3"
         provider = set_mocked_aws_provider()
@@ -45,11 +45,12 @@ def mock_make_api_call(self, operation_name, kwarg):
     elif operation_name == "ListBuildsForProject":
         return {"ids": [build_id]}
     elif operation_name == "BatchGetBuilds":
-        return {"builds": [{"endTime": last_invoked_time}]}
+        return {"builds": [{"id": build_id, "endTime": last_invoked_time}]}
     elif operation_name == "BatchGetProjects":
         return {
             "projects": [
                 {
                     "arn": project_arn,
                     "source": {
                         "type": source_type,
                         "location": bitbucket_url,
@@ -230,3 +231,97 @@ class Test_Codebuild_Service:
         assert (
             codebuild.report_groups[report_group_arn].tags[0]["value"] == project_name
         )
+
+
+# Module-level state and helpers used by the chunking/out-of-order test below.
+# Kept at module level so the API-call mock is a plain function rather than a
+# closure defined inside the test method.
+TOTAL_PROJECTS = 150
+many_project_names = [f"project-{i}" for i in range(TOTAL_PROJECTS)]
+many_project_arns = [
+    f"arn:{AWS_COMMERCIAL_PARTITION}:codebuild:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:project/{name}"
+    for name in many_project_names
+]
+many_build_ids_for = {name: f"{name}:build-id" for name in many_project_names}
+many_end_times_for = {
+    name: datetime.now() - timedelta(days=i)
+    for i, name in enumerate(many_project_names)
+}
+many_name_by_build_id = {v: k for k, v in many_build_ids_for.items()}
+many_batch_call_sizes = {"BatchGetProjects": [], "BatchGetBuilds": []}
+
+
+def mock_make_api_call_many_projects(self, operation_name, kwarg):
+    if operation_name == "ListProjects":
+        return {"projects": many_project_names}
+    if operation_name == "ListBuildsForProject":
+        return {"ids": [many_build_ids_for[kwarg["projectName"]]]}
+    if operation_name == "BatchGetBuilds":
+        ids = kwarg["ids"]
+        many_batch_call_sizes["BatchGetBuilds"].append(len(ids))
+        # Reverse the response order to verify id->project mapping does not
+        # depend on response ordering.
+        builds = [
+            {"id": bid, "endTime": many_end_times_for[many_name_by_build_id[bid]]}
+            for bid in reversed(ids)
+        ]
+        return {"builds": builds}
+    if operation_name == "BatchGetProjects":
+        names = kwarg["names"]
+        many_batch_call_sizes["BatchGetProjects"].append(len(names))
+        # Reverse the response order to verify arn->project mapping does not
+        # depend on response ordering.
+        projects = [
+            {
+                "arn": f"arn:{AWS_COMMERCIAL_PARTITION}:codebuild:{AWS_REGION_EU_WEST_1}:{AWS_ACCOUNT_NUMBER}:project/{name}",
+                "source": {"type": "NO_SOURCE"},
+                "logsConfig": {},
+                "tags": [],
+                "projectVisibility": "PRIVATE",
+            }
+            for name in reversed(names)
+        ]
+        return {"projects": projects}
+    if operation_name == "ListReportGroups":
+        return {"reportGroups": []}
+    return make_api_call(self, operation_name, kwarg)
+
+
+class Test_Codebuild_Service_Batching:
+    @patch(
+        "botocore.client.BaseClient._make_api_call",
+        new=mock_make_api_call_many_projects,
+    )
+    @patch(
+        "prowler.providers.aws.aws_provider.AwsProvider.generate_regional_clients",
+        new=mock_generate_regional_clients,
+    )
+    @mock_aws
+    def test_codebuild_batches_chunks_over_100_projects_and_maps_out_of_order_responses(
+        self,
+    ):
+        """Verify _batch_get_projects/_batch_get_builds chunk in groups of 100
+        and correctly map out-of-order batch responses back to the right
+        project using `arn`/`id`.
+        """
+        # Reset the per-test recorder (module-level state survives across runs).
+        many_batch_call_sizes["BatchGetProjects"].clear()
+        many_batch_call_sizes["BatchGetBuilds"].clear()
+
+        codebuild = Codebuild(set_mocked_aws_provider([AWS_REGION_EU_WEST_1]))
+
+        # Verify chunking: 150 items -> two batches of 100 and 50.
+        assert sorted(many_batch_call_sizes["BatchGetProjects"]) == [50, 100]
+        assert sorted(many_batch_call_sizes["BatchGetBuilds"]) == [50, 100]
+
+        # Verify all projects were tracked.
+        assert len(codebuild.projects) == TOTAL_PROJECTS
+
+        # Verify out-of-order responses were correctly mapped back to the
+        # right project by `arn` (projects) and `id` (builds).
+        for name, arn in zip(many_project_names, many_project_arns):
+            project = codebuild.projects[arn]
+            assert project.name == name
+            assert project.project_visibility == "PRIVATE"
+            assert project.last_build == Build(id=many_build_ids_for[name])
+            assert project.last_invoked_time == many_end_times_for[name]
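The batching test above exercises two behaviors: chunking ids into groups of at most 100 and re-associating out-of-order batch responses by their `id` field. A standalone sketch of that pattern (hypothetical helper names, not the actual `_batch_get_*` implementations):

```python
# Hypothetical sketch of batch-API chunking and out-of-order re-association.
BATCH_SIZE = 100


def chunked(items, size=BATCH_SIZE):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start : start + size]


def batch_lookup(ids, batch_api):
    """Call batch_api per chunk and key results by id, regardless of order."""
    results = {}
    for chunk in chunked(ids):
        for item in batch_api(chunk):  # responses may arrive in any order
            results[item["id"]] = item
    return results


# Example: a fake API that reverses its input, like the mock above.
def fake_api(chunk):
    return [{"id": i, "value": i * 2} for i in reversed(chunk)]


mapped = batch_lookup(list(range(150)), fake_api)
# 150 ids -> two calls (100 + 50); every id maps back to its own record.
```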
@@ -2,7 +2,6 @@ from argparse import Namespace
 from json import dumps
 
 from boto3 import client, session
-from botocore.config import Config
 from moto import mock_aws
 
 from prowler.config.config import (

@@ -133,10 +132,11 @@ def set_mocked_aws_provider(
     provider = AwsProvider()
 
     # Mock Session
-    provider._session.session_config = None
+    session_config = AwsProvider.set_session_config(None)
+    provider._session.session_config = session_config
     provider._session.original_session = original_session
     provider._session.current_session = audit_session
-    provider._session.session_config = Config()
+    audit_session._session.set_default_client_config(session_config)
     # Mock Identity
     provider._identity.account = audited_account
     provider._identity.account_arn = audited_account_arn