From a024ab31a03220b6a484845ff9a5c5c1751a82ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20Mart=C3=ADn?= Date: Thu, 17 Oct 2024 09:29:02 +0200 Subject: [PATCH] feat(scan): add arguments (#5427) --- prowler/exceptions/exceptions.py | 10 +- prowler/lib/check/checks_loader.py | 2 +- prowler/lib/scan/exceptions/exceptions.py | 76 ++++++++++ prowler/lib/scan/scan.py | 132 +++++++++++++++--- .../providers/aws/exceptions/exceptions.py | 2 +- .../providers/azure/exceptions/exceptions.py | 2 +- .../providers/azure/lib/exception/__init__.py | 0 .../azure/lib/exception/exception.py | 11 -- .../providers/gcp/exceptions/exceptions.py | 2 +- .../kubernetes/exceptions/exceptions.py | 2 +- tests/lib/scan/scan_test.py | 105 +++++++++++++- 11 files changed, 296 insertions(+), 48 deletions(-) create mode 100644 prowler/lib/scan/exceptions/exceptions.py delete mode 100644 prowler/providers/azure/lib/exception/__init__.py delete mode 100644 prowler/providers/azure/lib/exception/exception.py diff --git a/prowler/exceptions/exceptions.py b/prowler/exceptions/exceptions.py index 352733afc9..ea69a2655f 100644 --- a/prowler/exceptions/exceptions.py +++ b/prowler/exceptions/exceptions.py @@ -9,14 +9,14 @@ class ProwlerException(Exception): } def __init__( - self, code, provider=None, file=None, original_exception=None, error_info=None + self, code, source=None, file=None, original_exception=None, error_info=None ): """ Initialize the ProwlerException class. Args: code (int): The error code. - provider (str): The provider name. + source (str): The source name. This can be the provider name, module name, service name, etc. file (str): The file name. original_exception (Exception): The original exception. error_info (dict): The error information. @@ -28,7 +28,7 @@ class ProwlerException(Exception): >>> [1901] Unexpected error occurred. - Exception: Error occurred. """ self.code = code - self.provider = provider + self.source = source self.file = file if error_info is None: error_info = self.ERROR_CODES.get((code, self.__class__.__name__)) @@ -52,5 +52,5 @@ class ProwlerException(Exception): class UnexpectedError(ProwlerException): - def __init__(self, provider, file, original_exception=None): - super().__init__(1901, provider, file, original_exception) + def __init__(self, source, file, original_exception=None): + super().__init__(1901, source, file, original_exception) diff --git a/prowler/lib/check/checks_loader.py b/prowler/lib/check/checks_loader.py index b8496c2b2b..acc97cb059 100644 --- a/prowler/lib/check/checks_loader.py +++ b/prowler/lib/check/checks_loader.py @@ -102,7 +102,7 @@ def load_checks_to_execute( checks_to_execute.add(check_name) # Only execute threat detection checks if threat-detection category is set - if "threat-detection" not in categories: + if categories != [] and "threat-detection" not in categories: for threat_detection_check in check_categories.get("threat-detection", []): checks_to_execute.discard(threat_detection_check) diff --git a/prowler/lib/scan/exceptions/exceptions.py b/prowler/lib/scan/exceptions/exceptions.py new file mode 100644 index 0000000000..96b11acd43 --- /dev/null +++ b/prowler/lib/scan/exceptions/exceptions.py @@ -0,0 +1,76 @@ +from prowler.exceptions.exceptions import ProwlerException + + +class ScanBaseException(ProwlerException): + """Base class for Scan errors.""" + + SCAN_ERROR_CODES = { + (2000, "ScanInvalidSeverityError"): { + "message": "Invalid severity level provided.", + "remediation": "Please provide a valid severity level. Valid severities are: critical, high, medium, low, informational.", + }, + (2001, "ScanInvalidCheckError"): { + "message": "Invalid check provided.", + "remediation": "Please provide a valid check name.", + }, + (2002, "ScanInvalidServiceError"): { + "message": "Invalid service provided.", + "remediation": "Please provide a valid service name.", + }, + (2003, "ScanInvalidComplianceFrameworkError"): { + "message": "Invalid compliance framework provided.", + "remediation": "Please provide a valid compliance framework name for the chosen provider.", + }, + (2004, "ScanInvalidCategoryError"): { + "message": "Invalid category provided.", + "remediation": "Please provide a valid category name.", + }, + } + + def __init__(self, code, file=None, original_exception=None, message=None): + module = "Scan" + error_info = self.SCAN_ERROR_CODES.get((code, self.__class__.__name__)) + if message: + error_info["message"] = message + super().__init__( + code=code, + source=module, + file=file, + original_exception=original_exception, + error_info=error_info, + ) + + +class ScanInvalidSeverityError(ScanBaseException): + def __init__(self, file=None, original_exception=None, message=None): + super().__init__( + 2000, file=file, original_exception=original_exception, message=message + ) + + +class ScanInvalidCheckError(ScanBaseException): + def __init__(self, file=None, original_exception=None, message=None): + super().__init__( + 2001, file=file, original_exception=original_exception, message=message + ) + + +class ScanInvalidServiceError(ScanBaseException): + def __init__(self, file=None, original_exception=None, message=None): + super().__init__( + 2002, file=file, original_exception=original_exception, message=message + ) + + +class ScanInvalidComplianceFrameworkError(ScanBaseException): + def __init__(self, file=None, original_exception=None, message=None): + super().__init__( + 2003, file=file, original_exception=original_exception, message=message + ) + + +class ScanInvalidCategoryError(ScanBaseException): + def __init__(self, file=None, original_exception=None, message=None): + super().__init__( + 2004, file=file, original_exception=original_exception, message=message + ) diff --git a/prowler/lib/scan/scan.py b/prowler/lib/scan/scan.py index 27d2005aa7..05b8462a39 100644 --- a/prowler/lib/scan/scan.py +++ b/prowler/lib/scan/scan.py @@ -1,10 +1,25 @@ import datetime from typing import Generator -from prowler.lib.check.check import execute, import_check, update_audit_metadata -from prowler.lib.check.utils import recover_checks_from_provider +from prowler.lib.check.check import ( + execute, + import_check, + list_services, + update_audit_metadata, +) +from prowler.lib.check.checks_loader import load_checks_to_execute +from prowler.lib.check.compliance import update_checks_metadata_with_compliance +from prowler.lib.check.compliance_models import Compliance +from prowler.lib.check.models import CheckMetadata from prowler.lib.logger import logger -from prowler.lib.outputs.finding import Finding +from prowler.lib.outputs.finding import Finding, Severity +from prowler.lib.scan.exceptions.exceptions import ( + ScanInvalidCategoryError, + ScanInvalidCheckError, + ScanInvalidComplianceFrameworkError, + ScanInvalidServiceError, + ScanInvalidSeverityError, +) from prowler.providers.common.models import Audit_Metadata from prowler.providers.common.provider import Provider @@ -22,35 +37,106 @@ class Scan: _findings: list = [] _duration: int = 0 - def __init__(self, provider: Provider, checks_to_execute: list[str] = None): + def __init__( + self, + provider: Provider, + checks: list[str] = None, + services: list[str] = None, + compliances: list[str] = None, + categories: list[str] = None, + severities: list[str] = None, + ): """ Scan is the class that executes the checks and yields the progress and the findings. Params: provider: Provider -> The provider to scan - checks_to_execute: list[str] -> The checks to execute + checks: list[str] -> The checks to execute + services: list[str] -> The services to scan + compliances: list[str] -> The compliances to check + categories: list[str] -> The categories of the checks + severities: list[str] -> The severities of the checks + + Raises: + ScanInvalidCheckError: If the check does not exist in the provider or is from another provider. + ScanInvalidServiceError: If the service does not exist in the provider. + ScanInvalidComplianceFrameworkError: If the compliance framework does not exist in the provider. + ScanInvalidCategoryError: If the category does not exist in the provider. + ScanInvalidSeverityError: If the severity does not exist in the provider. """ self._provider = provider - # Remove duplicated checks and sort them - self._checks_to_execute = ( - sorted(list(set(checks_to_execute))) - if checks_to_execute - else sorted( - [check[0] for check in recover_checks_from_provider(provider.type)] - ) + + # Load bulk compliance frameworks + bulk_compliance_frameworks = Compliance.get_bulk(provider.type) + + # Get bulk checks metadata for the provider + bulk_checks_metadata = CheckMetadata.get_bulk(provider.type) + # Complete checks metadata with the compliance framework specification + bulk_checks_metadata = update_checks_metadata_with_compliance( + bulk_compliance_frameworks, bulk_checks_metadata ) - # TODO This should be done depending on the scan args (future feature) - # Discard threat detection checks - if "cloudtrail_threat_detection_enumeration" in self._checks_to_execute: - self._checks_to_execute.remove("cloudtrail_threat_detection_enumeration") - if ( - "cloudtrail_threat_detection_privilege_escalation" - in self._checks_to_execute - ): - self._checks_to_execute.remove( - "cloudtrail_threat_detection_privilege_escalation" + # Create a list of valid categories + valid_categories = set() + for check, metadata in bulk_checks_metadata.items(): + for category in metadata.Categories: + if category not in valid_categories: + valid_categories.add(category) + + # Validate checks + if checks: + for check in checks: + if check not in bulk_checks_metadata.keys(): + raise ScanInvalidCheckError(f"Invalid check provided: {check}.") + + # Validate services + if services: + for service in services: + if service not in list_services(provider.type): + raise ScanInvalidServiceError( + f"Invalid service provided: {service}." + ) + + # Validate compliances + if compliances: + for compliance in compliances: + if compliance not in bulk_compliance_frameworks.keys(): + raise ScanInvalidComplianceFrameworkError( + f"Invalid compliance provided: {compliance}." + ) + + # Validate categories + if categories: + for category in categories: + if category not in valid_categories: + raise ScanInvalidCategoryError( + f"Invalid category provided: {category}." + ) + + # Validate severity + if severities: + for severity in severities: + try: + Severity(severity) + except ValueError: + raise ScanInvalidSeverityError( + f"Invalid severity provided: {severity}." + ) + + # Load checks to execute + self._checks_to_execute = sorted( + load_checks_to_execute( + bulk_checks_metadata=bulk_checks_metadata, + bulk_compliance_frameworks=bulk_compliance_frameworks, + check_list=checks, + service_list=services, + compliance_frameworks=compliances, + categories=categories, + severities=severities, + provider=provider.type, + checks_file=None, ) + ) self._number_of_checks_to_execute = len(self._checks_to_execute) @@ -63,7 +149,7 @@ class Scan: self._service_checks_completed = service_checks_completed @property - def checks_to_execute(self) -> set[str]: + def checks_to_execute(self) -> list[str]: return self._checks_to_execute @property diff --git a/prowler/providers/aws/exceptions/exceptions.py b/prowler/providers/aws/exceptions/exceptions.py index 579dd88e97..9cca1855be 100644 --- a/prowler/providers/aws/exceptions/exceptions.py +++ b/prowler/providers/aws/exceptions/exceptions.py @@ -77,7 +77,7 @@ class AWSBaseException(ProwlerException): error_info["message"] = message super().__init__( code, - provider="AWS", + source="AWS", file=file, original_exception=original_exception, error_info=error_info, diff --git a/prowler/providers/azure/exceptions/exceptions.py b/prowler/providers/azure/exceptions/exceptions.py index 163fbb424d..9b4d0e443b 100644 --- a/prowler/providers/azure/exceptions/exceptions.py +++ b/prowler/providers/azure/exceptions/exceptions.py @@ -110,7 +110,7 @@ class AzureBaseException(ProwlerException): error_info["message"] = message super().__init__( code=code, - provider=provider, + source=provider, file=file, original_exception=original_exception, error_info=error_info, diff --git a/prowler/providers/azure/lib/exception/__init__.py b/prowler/providers/azure/lib/exception/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/prowler/providers/azure/lib/exception/exception.py b/prowler/providers/azure/lib/exception/exception.py deleted file mode 100644 index 51fe95d001..0000000000 --- a/prowler/providers/azure/lib/exception/exception.py +++ /dev/null @@ -1,11 +0,0 @@ -class AzureException(Exception): - """ - Exception raised when dealing with Azure Provider/Azure audit info instance - - Attributes: - message -- message to be displayed - """ - - def __init__(self, message): - self.message = message - super().__init__(self.message) diff --git a/prowler/providers/gcp/exceptions/exceptions.py b/prowler/providers/gcp/exceptions/exceptions.py index 5fe30484f6..c4dc7883be 100644 --- a/prowler/providers/gcp/exceptions/exceptions.py +++ b/prowler/providers/gcp/exceptions/exceptions.py @@ -50,7 +50,7 @@ class GCPBaseException(ProwlerException): error_info["message"] = message super().__init__( code=code, - provider=provider, + source=provider, file=file, original_exception=original_exception, error_info=error_info, diff --git a/prowler/providers/kubernetes/exceptions/exceptions.py b/prowler/providers/kubernetes/exceptions/exceptions.py index 46122a6406..57cfda39c6 100644 --- a/prowler/providers/kubernetes/exceptions/exceptions.py +++ b/prowler/providers/kubernetes/exceptions/exceptions.py @@ -44,7 +44,7 @@ class KubernetesBaseException(ProwlerException): error_info["message"] = message super().__init__( code=code, - provider=provider, + source=provider, file=file, original_exception=original_exception, error_info=error_info, diff --git a/tests/lib/scan/scan_test.py b/tests/lib/scan/scan_test.py index c737039917..079ac0bb13 100644 --- a/tests/lib/scan/scan_test.py +++ b/tests/lib/scan/scan_test.py @@ -5,6 +5,13 @@ from unittest import mock import pytest from mock import MagicMock, patch +from prowler.lib.scan.exceptions.exceptions import ( + ScanInvalidCategoryError, + ScanInvalidCheckError, + ScanInvalidComplianceFrameworkError, + ScanInvalidServiceError, + ScanInvalidSeverityError, +) from prowler.lib.scan.scan import Scan, get_service_checks_to_execute from tests.lib.outputs.fixtures.fixtures import generate_finding_output from tests.providers.aws.utils import set_mocked_aws_provider @@ -88,6 +95,31 @@ def mock_list_modules(): yield mock_list_mod +@pytest.fixture +def mock_recover_checks_from_provider(): + with mock.patch( + "prowler.lib.check.models.recover_checks_from_provider", autospec=True + ) as mock_recover: + mock_recover.return_value = [ + ( + "accessanalyzer_enabled", + "/prowler/providers/aws/services/accessanalyzer/accessanalyzer_enabled", + ) + ] + yield mock_recover + + +@pytest.fixture +def mock_load_check_metadata(): + with mock.patch( + "prowler.lib.check.models.load_check_metadata", autospec=True + ) as mock_load: + mock_metadata = MagicMock() + mock_metadata.CheckID = "accessanalyzer_enabled" + mock_load.return_value = mock_metadata + yield mock_load + + class TestScan: def test_init(mock_provider): checks_to_execute = { @@ -151,7 +183,8 @@ class TestScan: "cognito_user_pool_waf_acl_attached", "config_recorder_all_regions_enabled", } - scan = Scan(mock_provider, checks_to_execute) + mock_provider.type = "aws" + scan = Scan(mock_provider, checks=checks_to_execute) assert scan.provider == mock_provider # Check that the checks to execute are sorted and without duplicates @@ -224,11 +257,19 @@ class TestScan: assert scan.get_completed_services() == set() assert scan.get_completed_checks() == set() - def test_init_with_no_checks(mock_provider, mock_list_modules): + def test_init_with_no_checks( + mock_provider, + mock_list_modules, + mock_recover_checks_from_provider, + mock_load_check_metadata, + ): checks_to_execute = set() mock_provider.type = "aws" - scan = Scan(mock_provider, checks_to_execute) + scan = Scan(mock_provider, checks=checks_to_execute) + mock_list_modules.assert_called_once_with("aws", None) + mock_load_check_metadata.assert_called_once() + mock_recover_checks_from_provider.assert_called_once_with("aws") assert scan.provider == mock_provider assert scan.checks_to_execute == ["accessanalyzer_enabled"] @@ -247,12 +288,15 @@ class TestScan: mock_execute, mock_logger, mock_generate_output, + mock_recover_checks_from_provider, + mock_load_check_metadata, ): mock_check_class = MagicMock() mock_check_instance = mock_check_class.return_value mock_check_instance.Provider = "aws" mock_check_instance.CheckID = "accessanalyzer_enabled" mock_check_instance.CheckTitle = "Check if IAM Access Analyzer is enabled" + mock_check_instance.Categories = [] mock_import_module.return_value = MagicMock( accessanalyzer_enabled=mock_check_class @@ -262,7 +306,9 @@ class TestScan: custom_checks_metadata = {} mock_global_provider.type = "aws" - scan = Scan(mock_global_provider, checks_to_execute) + scan = Scan(mock_global_provider, checks=checks_to_execute) + mock_load_check_metadata.assert_called_once() + mock_recover_checks_from_provider.assert_called_once_with("aws") results = list(scan.scan(custom_checks_metadata)) assert mock_generate_output.call_count == 1 * len(mock_execute.side_effect()) @@ -279,3 +325,54 @@ class TestScan: } assert scan.findings == mock_execute.side_effect() mock_logger.error.assert_not_called() + + def test_init_invalid_severity( + mock_provider, + ): + checks_to_execute = set() + mock_provider.type = "aws" + + with pytest.raises(ScanInvalidSeverityError): + Scan(mock_provider, checks=checks_to_execute, severities=["invalid"]) + + def test_init_invalid_check( + mock_provider, + ): + checks_to_execute = ["invalid_check"] + mock_provider.type = "aws" + + with pytest.raises(ScanInvalidCheckError): + Scan(mock_provider, checks=checks_to_execute) + + def test_init_invalid_service( + mock_provider, + ): + checks_to_execute = set() + mock_provider.type = "aws" + + with pytest.raises(ScanInvalidServiceError): + Scan(mock_provider, checks=checks_to_execute, services=["invalid_service"]) + + def test_init_invalid_compliance_framework( + mock_provider, + ): + checks_to_execute = set() + mock_provider.type = "aws" + + with pytest.raises(ScanInvalidComplianceFrameworkError): + Scan( + mock_provider, + checks=checks_to_execute, + compliances=["invalid_framework"], + ) + + def test_init_invalid_category( + mock_provider, + ): + checks_to_execute = set() + mock_provider.type = "aws" + + with pytest.raises(ScanInvalidCategoryError): + Scan( + mock_provider, checks=checks_to_execute, categories=["invalid_category"] + )