feat(scan): add arguments (#5427)

This commit is contained in:
Pedro Martín
2024-10-17 09:29:02 +02:00
committed by GitHub
parent 9969e271ed
commit a024ab31a0
11 changed files with 296 additions and 48 deletions
+5 -5
View File
@@ -9,14 +9,14 @@ class ProwlerException(Exception):
}
def __init__(
self, code, provider=None, file=None, original_exception=None, error_info=None
self, code, source=None, file=None, original_exception=None, error_info=None
):
"""
Initialize the ProwlerException class.
Args:
code (int): The error code.
provider (str): The provider name.
source (str): The source name. This can be the provider name, module name, service name, etc.
file (str): The file name.
original_exception (Exception): The original exception.
error_info (dict): The error information.
@@ -28,7 +28,7 @@ class ProwlerException(Exception):
>>> [1901] Unexpected error occurred. - Exception: Error occurred.
"""
self.code = code
self.provider = provider
self.source = source
self.file = file
if error_info is None:
error_info = self.ERROR_CODES.get((code, self.__class__.__name__))
@@ -52,5 +52,5 @@ class ProwlerException(Exception):
class UnexpectedError(ProwlerException):
def __init__(self, provider, file, original_exception=None):
super().__init__(1901, provider, file, original_exception)
def __init__(self, source, file, original_exception=None):
super().__init__(1901, source, file, original_exception)
+1 -1
View File
@@ -102,7 +102,7 @@ def load_checks_to_execute(
checks_to_execute.add(check_name)
# Only execute threat detection checks if threat-detection category is set
if "threat-detection" not in categories:
if categories != [] and "threat-detection" not in categories:
for threat_detection_check in check_categories.get("threat-detection", []):
checks_to_execute.discard(threat_detection_check)
+76
View File
@@ -0,0 +1,76 @@
from prowler.exceptions.exceptions import ProwlerException
class ScanBaseException(ProwlerException):
"""Base class for Scan errors."""
SCAN_ERROR_CODES = {
(2000, "ScanInvalidSeverityError"): {
"message": "Invalid severity level provided.",
"remediation": "Please provide a valid severity level. Valid severities are: critical, high, medium, low, informational.",
},
(2001, "ScanInvalidCheckError"): {
"message": "Invalid check provided.",
"remediation": "Please provide a valid check name.",
},
(2002, "ScanInvalidServiceError"): {
"message": "Invalid service provided.",
"remediation": "Please provide a valid service name.",
},
(2003, "ScanInvalidComplianceFrameworkError"): {
"message": "Invalid compliance framework provided.",
"remediation": "Please provide a valid compliance framework name for the chosen provider.",
},
(2004, "ScanInvalidCategoryError"): {
"message": "Invalid category provided.",
"remediation": "Please provide a valid category name.",
},
}
def __init__(self, code, file=None, original_exception=None, message=None):
module = "Scan"
error_info = self.SCAN_ERROR_CODES.get((code, self.__class__.__name__))
if message:
error_info["message"] = message
super().__init__(
code=code,
source=module,
file=file,
original_exception=original_exception,
error_info=error_info,
)
class ScanInvalidSeverityError(ScanBaseException):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
2000, file=file, original_exception=original_exception, message=message
)
class ScanInvalidCheckError(ScanBaseException):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
2001, file=file, original_exception=original_exception, message=message
)
class ScanInvalidServiceError(ScanBaseException):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
2002, file=file, original_exception=original_exception, message=message
)
class ScanInvalidComplianceFrameworkError(ScanBaseException):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
2003, file=file, original_exception=original_exception, message=message
)
class ScanInvalidCategoryError(ScanBaseException):
def __init__(self, file=None, original_exception=None, message=None):
super().__init__(
2004, file=file, original_exception=original_exception, message=message
)
+109 -23
View File
@@ -1,10 +1,25 @@
import datetime
from typing import Generator
from prowler.lib.check.check import execute, import_check, update_audit_metadata
from prowler.lib.check.utils import recover_checks_from_provider
from prowler.lib.check.check import (
execute,
import_check,
list_services,
update_audit_metadata,
)
from prowler.lib.check.checks_loader import load_checks_to_execute
from prowler.lib.check.compliance import update_checks_metadata_with_compliance
from prowler.lib.check.compliance_models import Compliance
from prowler.lib.check.models import CheckMetadata
from prowler.lib.logger import logger
from prowler.lib.outputs.finding import Finding
from prowler.lib.outputs.finding import Finding, Severity
from prowler.lib.scan.exceptions.exceptions import (
ScanInvalidCategoryError,
ScanInvalidCheckError,
ScanInvalidComplianceFrameworkError,
ScanInvalidServiceError,
ScanInvalidSeverityError,
)
from prowler.providers.common.models import Audit_Metadata
from prowler.providers.common.provider import Provider
@@ -22,35 +37,106 @@ class Scan:
_findings: list = []
_duration: int = 0
def __init__(self, provider: Provider, checks_to_execute: list[str] = None):
def __init__(
self,
provider: Provider,
checks: list[str] = None,
services: list[str] = None,
compliances: list[str] = None,
categories: list[str] = None,
severities: list[str] = None,
):
"""
Scan is the class that executes the checks and yields the progress and the findings.
Params:
provider: Provider -> The provider to scan
checks_to_execute: list[str] -> The checks to execute
checks: list[str] -> The checks to execute
services: list[str] -> The services to scan
compliances: list[str] -> The compliances to check
categories: list[str] -> The categories of the checks
severities: list[str] -> The severities of the checks
Raises:
ScanInvalidCheckError: If the check does not exist in the provider or is from another provider.
ScanInvalidServiceError: If the service does not exist in the provider.
ScanInvalidComplianceFrameworkError: If the compliance framework does not exist in the provider.
ScanInvalidCategoryError: If the category does not exist in the provider.
ScanInvalidSeverityError: If the severity does not exist in the provider.
"""
self._provider = provider
# Remove duplicated checks and sort them
self._checks_to_execute = (
sorted(list(set(checks_to_execute)))
if checks_to_execute
else sorted(
[check[0] for check in recover_checks_from_provider(provider.type)]
)
# Load bulk compliance frameworks
bulk_compliance_frameworks = Compliance.get_bulk(provider.type)
# Get bulk checks metadata for the provider
bulk_checks_metadata = CheckMetadata.get_bulk(provider.type)
# Complete checks metadata with the compliance framework specification
bulk_checks_metadata = update_checks_metadata_with_compliance(
bulk_compliance_frameworks, bulk_checks_metadata
)
# TODO This should be done depending on the scan args (future feature)
# Discard threat detection checks
if "cloudtrail_threat_detection_enumeration" in self._checks_to_execute:
self._checks_to_execute.remove("cloudtrail_threat_detection_enumeration")
if (
"cloudtrail_threat_detection_privilege_escalation"
in self._checks_to_execute
):
self._checks_to_execute.remove(
"cloudtrail_threat_detection_privilege_escalation"
# Create a list of valid categories
valid_categories = set()
for check, metadata in bulk_checks_metadata.items():
for category in metadata.Categories:
if category not in valid_categories:
valid_categories.add(category)
# Validate checks
if checks:
for check in checks:
if check not in bulk_checks_metadata.keys():
raise ScanInvalidCheckError(f"Invalid check provided: {check}.")
# Validate services
if services:
for service in services:
if service not in list_services(provider.type):
raise ScanInvalidServiceError(
f"Invalid service provided: {service}."
)
# Validate compliances
if compliances:
for compliance in compliances:
if compliance not in bulk_compliance_frameworks.keys():
raise ScanInvalidComplianceFrameworkError(
f"Invalid compliance provided: {compliance}."
)
# Validate categories
if categories:
for category in categories:
if category not in valid_categories:
raise ScanInvalidCategoryError(
f"Invalid category provided: {category}."
)
# Validate severity
if severities:
for severity in severities:
try:
Severity(severity)
except ValueError:
raise ScanInvalidSeverityError(
f"Invalid severity provided: {severity}."
)
# Load checks to execute
self._checks_to_execute = sorted(
load_checks_to_execute(
bulk_checks_metadata=bulk_checks_metadata,
bulk_compliance_frameworks=bulk_compliance_frameworks,
check_list=checks,
service_list=services,
compliance_frameworks=compliances,
categories=categories,
severities=severities,
provider=provider.type,
checks_file=None,
)
)
self._number_of_checks_to_execute = len(self._checks_to_execute)
@@ -63,7 +149,7 @@ class Scan:
self._service_checks_completed = service_checks_completed
@property
def checks_to_execute(self) -> set[str]:
def checks_to_execute(self) -> list[str]:
return self._checks_to_execute
@property
@@ -77,7 +77,7 @@ class AWSBaseException(ProwlerException):
error_info["message"] = message
super().__init__(
code,
provider="AWS",
source="AWS",
file=file,
original_exception=original_exception,
error_info=error_info,
@@ -110,7 +110,7 @@ class AzureBaseException(ProwlerException):
error_info["message"] = message
super().__init__(
code=code,
provider=provider,
source=provider,
file=file,
original_exception=original_exception,
error_info=error_info,
@@ -1,11 +0,0 @@
class AzureException(Exception):
"""
Exception raised when dealing with Azure Provider/Azure audit info instance
Attributes:
message -- message to be displayed
"""
def __init__(self, message):
self.message = message
super().__init__(self.message)
@@ -50,7 +50,7 @@ class GCPBaseException(ProwlerException):
error_info["message"] = message
super().__init__(
code=code,
provider=provider,
source=provider,
file=file,
original_exception=original_exception,
error_info=error_info,
@@ -44,7 +44,7 @@ class KubernetesBaseException(ProwlerException):
error_info["message"] = message
super().__init__(
code=code,
provider=provider,
source=provider,
file=file,
original_exception=original_exception,
error_info=error_info,
+101 -4
View File
@@ -5,6 +5,13 @@ from unittest import mock
import pytest
from mock import MagicMock, patch
from prowler.lib.scan.exceptions.exceptions import (
ScanInvalidCategoryError,
ScanInvalidCheckError,
ScanInvalidComplianceFrameworkError,
ScanInvalidServiceError,
ScanInvalidSeverityError,
)
from prowler.lib.scan.scan import Scan, get_service_checks_to_execute
from tests.lib.outputs.fixtures.fixtures import generate_finding_output
from tests.providers.aws.utils import set_mocked_aws_provider
@@ -88,6 +95,31 @@ def mock_list_modules():
yield mock_list_mod
@pytest.fixture
def mock_recover_checks_from_provider():
with mock.patch(
"prowler.lib.check.models.recover_checks_from_provider", autospec=True
) as mock_recover:
mock_recover.return_value = [
(
"accessanalyzer_enabled",
"/prowler/providers/aws/services/accessanalyzer/accessanalyzer_enabled",
)
]
yield mock_recover
@pytest.fixture
def mock_load_check_metadata():
with mock.patch(
"prowler.lib.check.models.load_check_metadata", autospec=True
) as mock_load:
mock_metadata = MagicMock()
mock_metadata.CheckID = "accessanalyzer_enabled"
mock_load.return_value = mock_metadata
yield mock_load
class TestScan:
def test_init(mock_provider):
checks_to_execute = {
@@ -151,7 +183,8 @@ class TestScan:
"cognito_user_pool_waf_acl_attached",
"config_recorder_all_regions_enabled",
}
scan = Scan(mock_provider, checks_to_execute)
mock_provider.type = "aws"
scan = Scan(mock_provider, checks=checks_to_execute)
assert scan.provider == mock_provider
# Check that the checks to execute are sorted and without duplicates
@@ -224,11 +257,19 @@ class TestScan:
assert scan.get_completed_services() == set()
assert scan.get_completed_checks() == set()
def test_init_with_no_checks(mock_provider, mock_list_modules):
def test_init_with_no_checks(
mock_provider,
mock_list_modules,
mock_recover_checks_from_provider,
mock_load_check_metadata,
):
checks_to_execute = set()
mock_provider.type = "aws"
scan = Scan(mock_provider, checks_to_execute)
scan = Scan(mock_provider, checks=checks_to_execute)
mock_list_modules.assert_called_once_with("aws", None)
mock_load_check_metadata.assert_called_once()
mock_recover_checks_from_provider.assert_called_once_with("aws")
assert scan.provider == mock_provider
assert scan.checks_to_execute == ["accessanalyzer_enabled"]
@@ -247,12 +288,15 @@ class TestScan:
mock_execute,
mock_logger,
mock_generate_output,
mock_recover_checks_from_provider,
mock_load_check_metadata,
):
mock_check_class = MagicMock()
mock_check_instance = mock_check_class.return_value
mock_check_instance.Provider = "aws"
mock_check_instance.CheckID = "accessanalyzer_enabled"
mock_check_instance.CheckTitle = "Check if IAM Access Analyzer is enabled"
mock_check_instance.Categories = []
mock_import_module.return_value = MagicMock(
accessanalyzer_enabled=mock_check_class
@@ -262,7 +306,9 @@ class TestScan:
custom_checks_metadata = {}
mock_global_provider.type = "aws"
scan = Scan(mock_global_provider, checks_to_execute)
scan = Scan(mock_global_provider, checks=checks_to_execute)
mock_load_check_metadata.assert_called_once()
mock_recover_checks_from_provider.assert_called_once_with("aws")
results = list(scan.scan(custom_checks_metadata))
assert mock_generate_output.call_count == 1 * len(mock_execute.side_effect())
@@ -279,3 +325,54 @@ class TestScan:
}
assert scan.findings == mock_execute.side_effect()
mock_logger.error.assert_not_called()
def test_init_invalid_severity(
mock_provider,
):
checks_to_execute = set()
mock_provider.type = "aws"
with pytest.raises(ScanInvalidSeverityError):
Scan(mock_provider, checks=checks_to_execute, severities=["invalid"])
def test_init_invalid_check(
mock_provider,
):
checks_to_execute = ["invalid_check"]
mock_provider.type = "aws"
with pytest.raises(ScanInvalidCheckError):
Scan(mock_provider, checks=checks_to_execute)
def test_init_invalid_service(
mock_provider,
):
checks_to_execute = set()
mock_provider.type = "aws"
with pytest.raises(ScanInvalidServiceError):
Scan(mock_provider, checks=checks_to_execute, services=["invalid_service"])
def test_init_invalid_compliance_framework(
mock_provider,
):
checks_to_execute = set()
mock_provider.type = "aws"
with pytest.raises(ScanInvalidComplianceFrameworkError):
Scan(
mock_provider,
checks=checks_to_execute,
compliances=["invalid_framework"],
)
def test_init_invalid_category(
mock_provider,
):
checks_to_execute = set()
mock_provider.type = "aws"
with pytest.raises(ScanInvalidCategoryError):
Scan(
mock_provider, checks=checks_to_execute, categories=["invalid_category"]
)