mirror of
https://github.com/prowler-cloud/prowler.git
synced 2026-07-04 19:21:51 +00:00
feat(scan): add arguments (#5427)
This commit is contained in:
@@ -9,14 +9,14 @@ class ProwlerException(Exception):
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, code, provider=None, file=None, original_exception=None, error_info=None
|
||||
self, code, source=None, file=None, original_exception=None, error_info=None
|
||||
):
|
||||
"""
|
||||
Initialize the ProwlerException class.
|
||||
|
||||
Args:
|
||||
code (int): The error code.
|
||||
provider (str): The provider name.
|
||||
source (str): The source name. This can be the provider name, module name, service name, etc.
|
||||
file (str): The file name.
|
||||
original_exception (Exception): The original exception.
|
||||
error_info (dict): The error information.
|
||||
@@ -28,7 +28,7 @@ class ProwlerException(Exception):
|
||||
>>> [1901] Unexpected error occurred. - Exception: Error occurred.
|
||||
"""
|
||||
self.code = code
|
||||
self.provider = provider
|
||||
self.source = source
|
||||
self.file = file
|
||||
if error_info is None:
|
||||
error_info = self.ERROR_CODES.get((code, self.__class__.__name__))
|
||||
@@ -52,5 +52,5 @@ class ProwlerException(Exception):
|
||||
|
||||
|
||||
class UnexpectedError(ProwlerException):
|
||||
def __init__(self, provider, file, original_exception=None):
|
||||
super().__init__(1901, provider, file, original_exception)
|
||||
def __init__(self, source, file, original_exception=None):
|
||||
super().__init__(1901, source, file, original_exception)
|
||||
|
||||
@@ -102,7 +102,7 @@ def load_checks_to_execute(
|
||||
checks_to_execute.add(check_name)
|
||||
|
||||
# Only execute threat detection checks if threat-detection category is set
|
||||
if "threat-detection" not in categories:
|
||||
if categories != [] and "threat-detection" not in categories:
|
||||
for threat_detection_check in check_categories.get("threat-detection", []):
|
||||
checks_to_execute.discard(threat_detection_check)
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
from prowler.exceptions.exceptions import ProwlerException
|
||||
|
||||
|
||||
class ScanBaseException(ProwlerException):
|
||||
"""Base class for Scan errors."""
|
||||
|
||||
SCAN_ERROR_CODES = {
|
||||
(2000, "ScanInvalidSeverityError"): {
|
||||
"message": "Invalid severity level provided.",
|
||||
"remediation": "Please provide a valid severity level. Valid severities are: critical, high, medium, low, informational.",
|
||||
},
|
||||
(2001, "ScanInvalidCheckError"): {
|
||||
"message": "Invalid check provided.",
|
||||
"remediation": "Please provide a valid check name.",
|
||||
},
|
||||
(2002, "ScanInvalidServiceError"): {
|
||||
"message": "Invalid service provided.",
|
||||
"remediation": "Please provide a valid service name.",
|
||||
},
|
||||
(2003, "ScanInvalidComplianceFrameworkError"): {
|
||||
"message": "Invalid compliance framework provided.",
|
||||
"remediation": "Please provide a valid compliance framework name for the chosen provider.",
|
||||
},
|
||||
(2004, "ScanInvalidCategoryError"): {
|
||||
"message": "Invalid category provided.",
|
||||
"remediation": "Please provide a valid category name.",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, code, file=None, original_exception=None, message=None):
|
||||
module = "Scan"
|
||||
error_info = self.SCAN_ERROR_CODES.get((code, self.__class__.__name__))
|
||||
if message:
|
||||
error_info["message"] = message
|
||||
super().__init__(
|
||||
code=code,
|
||||
source=module,
|
||||
file=file,
|
||||
original_exception=original_exception,
|
||||
error_info=error_info,
|
||||
)
|
||||
|
||||
|
||||
class ScanInvalidSeverityError(ScanBaseException):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
2000, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
|
||||
class ScanInvalidCheckError(ScanBaseException):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
2001, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
|
||||
class ScanInvalidServiceError(ScanBaseException):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
2002, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
|
||||
class ScanInvalidComplianceFrameworkError(ScanBaseException):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
2003, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
|
||||
|
||||
class ScanInvalidCategoryError(ScanBaseException):
|
||||
def __init__(self, file=None, original_exception=None, message=None):
|
||||
super().__init__(
|
||||
2004, file=file, original_exception=original_exception, message=message
|
||||
)
|
||||
+109
-23
@@ -1,10 +1,25 @@
|
||||
import datetime
|
||||
from typing import Generator
|
||||
|
||||
from prowler.lib.check.check import execute, import_check, update_audit_metadata
|
||||
from prowler.lib.check.utils import recover_checks_from_provider
|
||||
from prowler.lib.check.check import (
|
||||
execute,
|
||||
import_check,
|
||||
list_services,
|
||||
update_audit_metadata,
|
||||
)
|
||||
from prowler.lib.check.checks_loader import load_checks_to_execute
|
||||
from prowler.lib.check.compliance import update_checks_metadata_with_compliance
|
||||
from prowler.lib.check.compliance_models import Compliance
|
||||
from prowler.lib.check.models import CheckMetadata
|
||||
from prowler.lib.logger import logger
|
||||
from prowler.lib.outputs.finding import Finding
|
||||
from prowler.lib.outputs.finding import Finding, Severity
|
||||
from prowler.lib.scan.exceptions.exceptions import (
|
||||
ScanInvalidCategoryError,
|
||||
ScanInvalidCheckError,
|
||||
ScanInvalidComplianceFrameworkError,
|
||||
ScanInvalidServiceError,
|
||||
ScanInvalidSeverityError,
|
||||
)
|
||||
from prowler.providers.common.models import Audit_Metadata
|
||||
from prowler.providers.common.provider import Provider
|
||||
|
||||
@@ -22,35 +37,106 @@ class Scan:
|
||||
_findings: list = []
|
||||
_duration: int = 0
|
||||
|
||||
def __init__(self, provider: Provider, checks_to_execute: list[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
provider: Provider,
|
||||
checks: list[str] = None,
|
||||
services: list[str] = None,
|
||||
compliances: list[str] = None,
|
||||
categories: list[str] = None,
|
||||
severities: list[str] = None,
|
||||
):
|
||||
"""
|
||||
Scan is the class that executes the checks and yields the progress and the findings.
|
||||
|
||||
Params:
|
||||
provider: Provider -> The provider to scan
|
||||
checks_to_execute: list[str] -> The checks to execute
|
||||
checks: list[str] -> The checks to execute
|
||||
services: list[str] -> The services to scan
|
||||
compliances: list[str] -> The compliances to check
|
||||
categories: list[str] -> The categories of the checks
|
||||
severities: list[str] -> The severities of the checks
|
||||
|
||||
Raises:
|
||||
ScanInvalidCheckError: If the check does not exist in the provider or is from another provider.
|
||||
ScanInvalidServiceError: If the service does not exist in the provider.
|
||||
ScanInvalidComplianceFrameworkError: If the compliance framework does not exist in the provider.
|
||||
ScanInvalidCategoryError: If the category does not exist in the provider.
|
||||
ScanInvalidSeverityError: If the severity does not exist in the provider.
|
||||
"""
|
||||
self._provider = provider
|
||||
# Remove duplicated checks and sort them
|
||||
self._checks_to_execute = (
|
||||
sorted(list(set(checks_to_execute)))
|
||||
if checks_to_execute
|
||||
else sorted(
|
||||
[check[0] for check in recover_checks_from_provider(provider.type)]
|
||||
)
|
||||
|
||||
# Load bulk compliance frameworks
|
||||
bulk_compliance_frameworks = Compliance.get_bulk(provider.type)
|
||||
|
||||
# Get bulk checks metadata for the provider
|
||||
bulk_checks_metadata = CheckMetadata.get_bulk(provider.type)
|
||||
# Complete checks metadata with the compliance framework specification
|
||||
bulk_checks_metadata = update_checks_metadata_with_compliance(
|
||||
bulk_compliance_frameworks, bulk_checks_metadata
|
||||
)
|
||||
|
||||
# TODO This should be done depending on the scan args (future feature)
|
||||
# Discard threat detection checks
|
||||
if "cloudtrail_threat_detection_enumeration" in self._checks_to_execute:
|
||||
self._checks_to_execute.remove("cloudtrail_threat_detection_enumeration")
|
||||
if (
|
||||
"cloudtrail_threat_detection_privilege_escalation"
|
||||
in self._checks_to_execute
|
||||
):
|
||||
self._checks_to_execute.remove(
|
||||
"cloudtrail_threat_detection_privilege_escalation"
|
||||
# Create a list of valid categories
|
||||
valid_categories = set()
|
||||
for check, metadata in bulk_checks_metadata.items():
|
||||
for category in metadata.Categories:
|
||||
if category not in valid_categories:
|
||||
valid_categories.add(category)
|
||||
|
||||
# Validate checks
|
||||
if checks:
|
||||
for check in checks:
|
||||
if check not in bulk_checks_metadata.keys():
|
||||
raise ScanInvalidCheckError(f"Invalid check provided: {check}.")
|
||||
|
||||
# Validate services
|
||||
if services:
|
||||
for service in services:
|
||||
if service not in list_services(provider.type):
|
||||
raise ScanInvalidServiceError(
|
||||
f"Invalid service provided: {service}."
|
||||
)
|
||||
|
||||
# Validate compliances
|
||||
if compliances:
|
||||
for compliance in compliances:
|
||||
if compliance not in bulk_compliance_frameworks.keys():
|
||||
raise ScanInvalidComplianceFrameworkError(
|
||||
f"Invalid compliance provided: {compliance}."
|
||||
)
|
||||
|
||||
# Validate categories
|
||||
if categories:
|
||||
for category in categories:
|
||||
if category not in valid_categories:
|
||||
raise ScanInvalidCategoryError(
|
||||
f"Invalid category provided: {category}."
|
||||
)
|
||||
|
||||
# Validate severity
|
||||
if severities:
|
||||
for severity in severities:
|
||||
try:
|
||||
Severity(severity)
|
||||
except ValueError:
|
||||
raise ScanInvalidSeverityError(
|
||||
f"Invalid severity provided: {severity}."
|
||||
)
|
||||
|
||||
# Load checks to execute
|
||||
self._checks_to_execute = sorted(
|
||||
load_checks_to_execute(
|
||||
bulk_checks_metadata=bulk_checks_metadata,
|
||||
bulk_compliance_frameworks=bulk_compliance_frameworks,
|
||||
check_list=checks,
|
||||
service_list=services,
|
||||
compliance_frameworks=compliances,
|
||||
categories=categories,
|
||||
severities=severities,
|
||||
provider=provider.type,
|
||||
checks_file=None,
|
||||
)
|
||||
)
|
||||
|
||||
self._number_of_checks_to_execute = len(self._checks_to_execute)
|
||||
|
||||
@@ -63,7 +149,7 @@ class Scan:
|
||||
self._service_checks_completed = service_checks_completed
|
||||
|
||||
@property
|
||||
def checks_to_execute(self) -> set[str]:
|
||||
def checks_to_execute(self) -> list[str]:
|
||||
return self._checks_to_execute
|
||||
|
||||
@property
|
||||
|
||||
@@ -77,7 +77,7 @@ class AWSBaseException(ProwlerException):
|
||||
error_info["message"] = message
|
||||
super().__init__(
|
||||
code,
|
||||
provider="AWS",
|
||||
source="AWS",
|
||||
file=file,
|
||||
original_exception=original_exception,
|
||||
error_info=error_info,
|
||||
|
||||
@@ -110,7 +110,7 @@ class AzureBaseException(ProwlerException):
|
||||
error_info["message"] = message
|
||||
super().__init__(
|
||||
code=code,
|
||||
provider=provider,
|
||||
source=provider,
|
||||
file=file,
|
||||
original_exception=original_exception,
|
||||
error_info=error_info,
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
class AzureException(Exception):
|
||||
"""
|
||||
Exception raised when dealing with Azure Provider/Azure audit info instance
|
||||
|
||||
Attributes:
|
||||
message -- message to be displayed
|
||||
"""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -50,7 +50,7 @@ class GCPBaseException(ProwlerException):
|
||||
error_info["message"] = message
|
||||
super().__init__(
|
||||
code=code,
|
||||
provider=provider,
|
||||
source=provider,
|
||||
file=file,
|
||||
original_exception=original_exception,
|
||||
error_info=error_info,
|
||||
|
||||
@@ -44,7 +44,7 @@ class KubernetesBaseException(ProwlerException):
|
||||
error_info["message"] = message
|
||||
super().__init__(
|
||||
code=code,
|
||||
provider=provider,
|
||||
source=provider,
|
||||
file=file,
|
||||
original_exception=original_exception,
|
||||
error_info=error_info,
|
||||
|
||||
+101
-4
@@ -5,6 +5,13 @@ from unittest import mock
|
||||
import pytest
|
||||
from mock import MagicMock, patch
|
||||
|
||||
from prowler.lib.scan.exceptions.exceptions import (
|
||||
ScanInvalidCategoryError,
|
||||
ScanInvalidCheckError,
|
||||
ScanInvalidComplianceFrameworkError,
|
||||
ScanInvalidServiceError,
|
||||
ScanInvalidSeverityError,
|
||||
)
|
||||
from prowler.lib.scan.scan import Scan, get_service_checks_to_execute
|
||||
from tests.lib.outputs.fixtures.fixtures import generate_finding_output
|
||||
from tests.providers.aws.utils import set_mocked_aws_provider
|
||||
@@ -88,6 +95,31 @@ def mock_list_modules():
|
||||
yield mock_list_mod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_recover_checks_from_provider():
|
||||
with mock.patch(
|
||||
"prowler.lib.check.models.recover_checks_from_provider", autospec=True
|
||||
) as mock_recover:
|
||||
mock_recover.return_value = [
|
||||
(
|
||||
"accessanalyzer_enabled",
|
||||
"/prowler/providers/aws/services/accessanalyzer/accessanalyzer_enabled",
|
||||
)
|
||||
]
|
||||
yield mock_recover
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_load_check_metadata():
|
||||
with mock.patch(
|
||||
"prowler.lib.check.models.load_check_metadata", autospec=True
|
||||
) as mock_load:
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.CheckID = "accessanalyzer_enabled"
|
||||
mock_load.return_value = mock_metadata
|
||||
yield mock_load
|
||||
|
||||
|
||||
class TestScan:
|
||||
def test_init(mock_provider):
|
||||
checks_to_execute = {
|
||||
@@ -151,7 +183,8 @@ class TestScan:
|
||||
"cognito_user_pool_waf_acl_attached",
|
||||
"config_recorder_all_regions_enabled",
|
||||
}
|
||||
scan = Scan(mock_provider, checks_to_execute)
|
||||
mock_provider.type = "aws"
|
||||
scan = Scan(mock_provider, checks=checks_to_execute)
|
||||
|
||||
assert scan.provider == mock_provider
|
||||
# Check that the checks to execute are sorted and without duplicates
|
||||
@@ -224,11 +257,19 @@ class TestScan:
|
||||
assert scan.get_completed_services() == set()
|
||||
assert scan.get_completed_checks() == set()
|
||||
|
||||
def test_init_with_no_checks(mock_provider, mock_list_modules):
|
||||
def test_init_with_no_checks(
|
||||
mock_provider,
|
||||
mock_list_modules,
|
||||
mock_recover_checks_from_provider,
|
||||
mock_load_check_metadata,
|
||||
):
|
||||
checks_to_execute = set()
|
||||
mock_provider.type = "aws"
|
||||
|
||||
scan = Scan(mock_provider, checks_to_execute)
|
||||
scan = Scan(mock_provider, checks=checks_to_execute)
|
||||
mock_list_modules.assert_called_once_with("aws", None)
|
||||
mock_load_check_metadata.assert_called_once()
|
||||
mock_recover_checks_from_provider.assert_called_once_with("aws")
|
||||
|
||||
assert scan.provider == mock_provider
|
||||
assert scan.checks_to_execute == ["accessanalyzer_enabled"]
|
||||
@@ -247,12 +288,15 @@ class TestScan:
|
||||
mock_execute,
|
||||
mock_logger,
|
||||
mock_generate_output,
|
||||
mock_recover_checks_from_provider,
|
||||
mock_load_check_metadata,
|
||||
):
|
||||
mock_check_class = MagicMock()
|
||||
mock_check_instance = mock_check_class.return_value
|
||||
mock_check_instance.Provider = "aws"
|
||||
mock_check_instance.CheckID = "accessanalyzer_enabled"
|
||||
mock_check_instance.CheckTitle = "Check if IAM Access Analyzer is enabled"
|
||||
mock_check_instance.Categories = []
|
||||
|
||||
mock_import_module.return_value = MagicMock(
|
||||
accessanalyzer_enabled=mock_check_class
|
||||
@@ -262,7 +306,9 @@ class TestScan:
|
||||
custom_checks_metadata = {}
|
||||
mock_global_provider.type = "aws"
|
||||
|
||||
scan = Scan(mock_global_provider, checks_to_execute)
|
||||
scan = Scan(mock_global_provider, checks=checks_to_execute)
|
||||
mock_load_check_metadata.assert_called_once()
|
||||
mock_recover_checks_from_provider.assert_called_once_with("aws")
|
||||
results = list(scan.scan(custom_checks_metadata))
|
||||
|
||||
assert mock_generate_output.call_count == 1 * len(mock_execute.side_effect())
|
||||
@@ -279,3 +325,54 @@ class TestScan:
|
||||
}
|
||||
assert scan.findings == mock_execute.side_effect()
|
||||
mock_logger.error.assert_not_called()
|
||||
|
||||
def test_init_invalid_severity(
|
||||
mock_provider,
|
||||
):
|
||||
checks_to_execute = set()
|
||||
mock_provider.type = "aws"
|
||||
|
||||
with pytest.raises(ScanInvalidSeverityError):
|
||||
Scan(mock_provider, checks=checks_to_execute, severities=["invalid"])
|
||||
|
||||
def test_init_invalid_check(
|
||||
mock_provider,
|
||||
):
|
||||
checks_to_execute = ["invalid_check"]
|
||||
mock_provider.type = "aws"
|
||||
|
||||
with pytest.raises(ScanInvalidCheckError):
|
||||
Scan(mock_provider, checks=checks_to_execute)
|
||||
|
||||
def test_init_invalid_service(
|
||||
mock_provider,
|
||||
):
|
||||
checks_to_execute = set()
|
||||
mock_provider.type = "aws"
|
||||
|
||||
with pytest.raises(ScanInvalidServiceError):
|
||||
Scan(mock_provider, checks=checks_to_execute, services=["invalid_service"])
|
||||
|
||||
def test_init_invalid_compliance_framework(
|
||||
mock_provider,
|
||||
):
|
||||
checks_to_execute = set()
|
||||
mock_provider.type = "aws"
|
||||
|
||||
with pytest.raises(ScanInvalidComplianceFrameworkError):
|
||||
Scan(
|
||||
mock_provider,
|
||||
checks=checks_to_execute,
|
||||
compliances=["invalid_framework"],
|
||||
)
|
||||
|
||||
def test_init_invalid_category(
|
||||
mock_provider,
|
||||
):
|
||||
checks_to_execute = set()
|
||||
mock_provider.type = "aws"
|
||||
|
||||
with pytest.raises(ScanInvalidCategoryError):
|
||||
Scan(
|
||||
mock_provider, checks=checks_to_execute, categories=["invalid_category"]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user